Practice 2. Recurrent Neural Networks¶

  • Alejandro Dopico Castro (alejandro.dopico2@udc.es).
  • Ana Xiangning Pereira Ezquerro (ana.ezquerro@udc.es).

The following notebook contains execution examples of the recurrent neural architecture proposed for the Walmart dataset. The Python scripts submitted include auxiliary code to simplify the readability of the code cells.

  • data.py: Includes the WalmartDataset class to instantiate each dataset.
  • model.py: Includes the WalmartModel class to instantiate a model with fixed hyperparameters and the DenormalizedMAE metric to use in the Keras fit() method.
  • plots.py: Includes auxiliary functions to display the time series performance of model predictions.
In [32]:
from data import * 
from plots import *
from model import *
from utils import *
from keras.layers import * 
from keras.models import Sequential, Model
from keras.optimizers import Adam, Optimizer, RMSprop
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.regularizers import Regularizer, L1, L2, L1L2
from tensorflow.data import Dataset
from itertools import product 
from collections import OrderedDict
import plotly.offline as pyo
pyo.init_notebook_mode()

Regularizer.__str__ = lambda x: str(x.__class__.__name__)
Optimizer.__str__ = lambda x: str(x.__class__.__name__) + f'({float(x.learning_rate.numpy()):1.0e})'


# global parameters 
TEST_RATIO = 0.2
VAL_RATIO = 0.15
BATCH_SIZE = 200

# load data 
data = WalmartDataset.load('Walmart.csv')
train, val, test = data.split(VAL_RATIO, TEST_RATIO)

Recurrent Neural Model¶

To model the temporal relations in the stream data, our neural architecture is a recurrent encoder ($\mathcal{E}$) with $\ell$ hidden layers of dimension $d_h$ that projects the input sequence $\mathbf{X}\in\mathbb{R}^{S\times d_x}$ to a time-contextualized sequence of embeddings $\mathbf{H} = \mathcal{E}(\mathbf{X}) \in \mathbb{R}^{S\times d_h}$, where $d_x$ and $d_h$ denote, respectively, the number of input features and the hidden dimension of the model, and $S$ denotes the sequence length. The result $\mathbf{H}$ is passed through a final recurrent (LSTM-based) layer, and its final state $\tilde{\mathbf{h}}\in\mathbb{R}^{d_h}$ is used as a summary of the sample. This representation is then fed to a feed-forward decoder composed of $\varphi$ dense layers, the last of which is constrained to a linear activation to predict the output value $\hat{y}$ (the number of sales expected at timestep $t+2$).
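The architecture above can be sketched with the Keras functional API. This is a minimal illustration, not the submitted WalmartModel implementation: the dimensions ($S$, $d_x$, $d_h$) and the ReLU activation in the decoder are assumptions chosen for the sketch.

```python
from keras.layers import Input, LSTM, Dense, Dropout
from keras.models import Model

# Illustrative dimensions (assumptions, not the notebook's defaults):
S, d_x, d_h = 3, 7, 10  # sequence length, input features, hidden dimension

inputs = Input(shape=(S, d_x))                 # X in R^{S x d_x}
h = LSTM(d_h, return_sequences=True)(inputs)   # encoder layer 1
h = LSTM(d_h, return_sequences=True)(h)        # encoder layer 2 -> H in R^{S x d_h}
summary = LSTM(d_h)(h)                         # final LSTM: last state h~ in R^{d_h}
z = Dense(d_h, activation='relu')(Dropout(0.1)(summary))  # decoder layer 1
y_hat = Dense(1, activation='linear')(z)       # linear output: predicted sales

model = Model(inputs, y_hat)
```

A single scalar is produced per input window, matching the one-value-ahead prediction $\hat{y}$.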

In this section we explore three possible values for the hyperparameter $S$ to assess the impact of past observations on sales modelling, keeping the other hyperparameters (number of layers, model dimension, activations, etc.) at their default values. The default configuration (baseline) uses an encoder of 2 stacked LSTMs with a decoder of 2 feed-forward layers. The only regularization method used is dropout (10%). This naive network can easily be improved, but we decided to start with the simplest architecture and incrementally increase the model's complexity while controlling overfitting with regularization methods.
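Varying $S$ amounts to changing how the series is windowed: each training sample is a window of $S$ consecutive observations whose target lies two steps beyond it. A minimal sketch of this windowing, with a hypothetical helper (`make_windows` is not part of the submitted scripts, which handle this inside WalmartDataset):

```python
import numpy as np

def make_windows(series, S, horizon=2):
    """Build (X, y) pairs: each input is S consecutive observations and the
    target is the value `horizon` steps after the last one (t+2 by default).
    Hypothetical helper for illustration only."""
    X, y = [], []
    for t in range(len(series) - S - horizon + 1):
        X.append(series[t:t + S])
        y.append(series[t + S + horizon - 1])
    return np.stack(X), np.array(y)

sales = np.arange(10.0)          # toy univariate series
X, y = make_windows(sales, S=3)  # 6 windows of length 3
```

Larger $S$ gives the encoder more context per sample but yields fewer training windows from the same series, which is part of the trade-off explored below.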

In [3]:
# S = 2
model2 = WalmartModel(2, hidden_size=10)
model2.train(train, val, 'results/walmart2.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model2.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 45ms/step - dmae: 542466.0625 - loss: 1.2749 - mae: 0.9731 - val_dmae: 482828.5000 - val_loss: 1.1356 - val_mae: 0.8661
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 529696.8750 - loss: 1.2212 - mae: 0.9502 - val_dmae: 451427.3438 - val_loss: 0.9896 - val_mae: 0.8098
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 476533.1250 - loss: 1.0042 - mae: 0.8548 - val_dmae: 320770.8438 - val_loss: 0.4981 - val_mae: 0.5754
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 283703.5625 - loss: 0.3883 - mae: 0.5089 - val_dmae: 196077.3125 - val_loss: 0.2602 - val_mae: 0.3517
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 133294.7500 - loss: 0.1291 - mae: 0.2391 - val_dmae: 177371.0000 - val_loss: 0.2245 - val_mae: 0.3182
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120702.5078 - loss: 0.1169 - mae: 0.2165 - val_dmae: 173767.5781 - val_loss: 0.2200 - val_mae: 0.3117
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117085.4062 - loss: 0.1121 - mae: 0.2100 - val_dmae: 171468.9688 - val_loss: 0.2138 - val_mae: 0.3076
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115684.9531 - loss: 0.1109 - mae: 0.2075 - val_dmae: 170553.6406 - val_loss: 0.2112 - val_mae: 0.3060
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113892.6484 - loss: 0.1075 - mae: 0.2043 - val_dmae: 169441.3125 - val_loss: 0.2082 - val_mae: 0.3040
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112650.3906 - loss: 0.1062 - mae: 0.2021 - val_dmae: 168496.5781 - val_loss: 0.2052 - val_mae: 0.3023
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113185.1719 - loss: 0.1087 - mae: 0.2030 - val_dmae: 167545.3125 - val_loss: 0.2021 - val_mae: 0.3006
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113233.1484 - loss: 0.1059 - mae: 0.2031 - val_dmae: 167308.2500 - val_loss: 0.1998 - val_mae: 0.3001
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112262.5703 - loss: 0.1044 - mae: 0.2014 - val_dmae: 167006.6406 - val_loss: 0.1984 - val_mae: 0.2996
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113321.8750 - loss: 0.1046 - mae: 0.2033 - val_dmae: 167198.8125 - val_loss: 0.1971 - val_mae: 0.2999
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 113176.3359 - loss: 0.1042 - mae: 0.2030 - val_dmae: 166525.8594 - val_loss: 0.1947 - val_mae: 0.2987
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112381.8125 - loss: 0.1026 - mae: 0.2016 - val_dmae: 165626.2031 - val_loss: 0.1924 - val_mae: 0.2971
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112691.5000 - loss: 0.1029 - mae: 0.2022 - val_dmae: 165702.7500 - val_loss: 0.1905 - val_mae: 0.2973
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 113656.5312 - loss: 0.1049 - mae: 0.2039 - val_dmae: 164918.7500 - val_loss: 0.1880 - val_mae: 0.2958
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113891.9297 - loss: 0.1029 - mae: 0.2043 - val_dmae: 164257.3125 - val_loss: 0.1861 - val_mae: 0.2947
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113350.8906 - loss: 0.1020 - mae: 0.2033 - val_dmae: 163191.5312 - val_loss: 0.1831 - val_mae: 0.2927
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 112985.3516 - loss: 0.1034 - mae: 0.2027 - val_dmae: 162745.6562 - val_loss: 0.1816 - val_mae: 0.2919
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113113.8594 - loss: 0.1030 - mae: 0.2029 - val_dmae: 162130.7500 - val_loss: 0.1797 - val_mae: 0.2908
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112037.4297 - loss: 0.0988 - mae: 0.2010 - val_dmae: 160907.2188 - val_loss: 0.1766 - val_mae: 0.2887
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 112930.2344 - loss: 0.1011 - mae: 0.2026 - val_dmae: 160469.9688 - val_loss: 0.1751 - val_mae: 0.2879
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111846.2500 - loss: 0.0988 - mae: 0.2006 - val_dmae: 159643.8750 - val_loss: 0.1730 - val_mae: 0.2864
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112309.9297 - loss: 0.0995 - mae: 0.2015 - val_dmae: 158771.9375 - val_loss: 0.1702 - val_mae: 0.2848
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110020.6875 - loss: 0.0967 - mae: 0.1974 - val_dmae: 158217.5156 - val_loss: 0.1692 - val_mae: 0.2838
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111687.5703 - loss: 0.0975 - mae: 0.2004 - val_dmae: 157435.3281 - val_loss: 0.1672 - val_mae: 0.2824
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111681.6797 - loss: 0.0976 - mae: 0.2003 - val_dmae: 156246.0625 - val_loss: 0.1646 - val_mae: 0.2803
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112036.5859 - loss: 0.1020 - mae: 0.2010 - val_dmae: 155778.1562 - val_loss: 0.1635 - val_mae: 0.2794
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111349.8750 - loss: 0.0993 - mae: 0.1997 - val_dmae: 155304.1250 - val_loss: 0.1619 - val_mae: 0.2786
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109372.3359 - loss: 0.0954 - mae: 0.1962 - val_dmae: 154837.9844 - val_loss: 0.1608 - val_mae: 0.2778
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110124.2031 - loss: 0.0962 - mae: 0.1976 - val_dmae: 153661.8125 - val_loss: 0.1587 - val_mae: 0.2757
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111761.6094 - loss: 0.0966 - mae: 0.2005 - val_dmae: 153093.3750 - val_loss: 0.1575 - val_mae: 0.2746
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109890.2266 - loss: 0.0937 - mae: 0.1971 - val_dmae: 151940.3281 - val_loss: 0.1552 - val_mae: 0.2726
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110978.9141 - loss: 0.0949 - mae: 0.1991 - val_dmae: 152149.7344 - val_loss: 0.1549 - val_mae: 0.2729
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111047.1562 - loss: 0.0962 - mae: 0.1992 - val_dmae: 151276.1406 - val_loss: 0.1535 - val_mae: 0.2714
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110416.0312 - loss: 0.0932 - mae: 0.1981 - val_dmae: 150854.2344 - val_loss: 0.1522 - val_mae: 0.2706
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108360.7578 - loss: 0.0916 - mae: 0.1944 - val_dmae: 149624.4531 - val_loss: 0.1505 - val_mae: 0.2684
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110382.8125 - loss: 0.0931 - mae: 0.1980 - val_dmae: 148468.9688 - val_loss: 0.1488 - val_mae: 0.2663
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 108698.6172 - loss: 0.0894 - mae: 0.1950 - val_dmae: 148712.5781 - val_loss: 0.1482 - val_mae: 0.2668
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109725.9766 - loss: 0.0910 - mae: 0.1968 - val_dmae: 148072.0312 - val_loss: 0.1471 - val_mae: 0.2656
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108324.7344 - loss: 0.0889 - mae: 0.1943 - val_dmae: 147546.8750 - val_loss: 0.1466 - val_mae: 0.2647
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109750.2891 - loss: 0.0924 - mae: 0.1969 - val_dmae: 147162.1562 - val_loss: 0.1459 - val_mae: 0.2640
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108304.9375 - loss: 0.0898 - mae: 0.1943 - val_dmae: 146951.3281 - val_loss: 0.1453 - val_mae: 0.2636
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107908.9453 - loss: 0.0876 - mae: 0.1936 - val_dmae: 146345.2031 - val_loss: 0.1441 - val_mae: 0.2625
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111194.6484 - loss: 0.0939 - mae: 0.1995 - val_dmae: 146156.3906 - val_loss: 0.1435 - val_mae: 0.2622
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108984.6250 - loss: 0.0897 - mae: 0.1955 - val_dmae: 146511.0469 - val_loss: 0.1434 - val_mae: 0.2628
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107806.2812 - loss: 0.0873 - mae: 0.1934 - val_dmae: 145242.2188 - val_loss: 0.1423 - val_mae: 0.2605
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 106928.0391 - loss: 0.0884 - mae: 0.1918 - val_dmae: 145242.1406 - val_loss: 0.1418 - val_mae: 0.2605
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 109188.1562 - loss: 0.0915 - mae: 0.1959 - val_dmae: 145059.5781 - val_loss: 0.1416 - val_mae: 0.2602
Epoch 52/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108148.0391 - loss: 0.0880 - mae: 0.1940 - val_dmae: 144448.9844 - val_loss: 0.1408 - val_mae: 0.2591
Epoch 53/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107990.7656 - loss: 0.0888 - mae: 0.1937 - val_dmae: 144357.6094 - val_loss: 0.1404 - val_mae: 0.2590
Epoch 54/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107526.2656 - loss: 0.0877 - mae: 0.1929 - val_dmae: 143896.8906 - val_loss: 0.1398 - val_mae: 0.2581
Epoch 55/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107085.8594 - loss: 0.0867 - mae: 0.1921 - val_dmae: 143877.1562 - val_loss: 0.1395 - val_mae: 0.2581
Epoch 56/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107904.9688 - loss: 0.0883 - mae: 0.1936 - val_dmae: 143210.7031 - val_loss: 0.1387 - val_mae: 0.2569
Epoch 57/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108230.6016 - loss: 0.0871 - mae: 0.1942 - val_dmae: 143528.7812 - val_loss: 0.1388 - val_mae: 0.2575
Epoch 58/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108813.1562 - loss: 0.0887 - mae: 0.1952 - val_dmae: 143870.9375 - val_loss: 0.1391 - val_mae: 0.2581
Epoch 59/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108596.9219 - loss: 0.0884 - mae: 0.1948 - val_dmae: 142628.6406 - val_loss: 0.1383 - val_mae: 0.2559
Epoch 60/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 107074.8047 - loss: 0.0867 - mae: 0.1921 - val_dmae: 143124.1719 - val_loss: 0.1380 - val_mae: 0.2567
Epoch 61/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108068.4219 - loss: 0.0877 - mae: 0.1939 - val_dmae: 143261.3281 - val_loss: 0.1381 - val_mae: 0.2570
Epoch 62/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 3s 32ms/step - dmae: 107847.1641 - loss: 0.0868 - mae: 0.1935 - val_dmae: 142399.4844 - val_loss: 0.1370 - val_mae: 0.2554
Epoch 63/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108915.8359 - loss: 0.0876 - mae: 0.1954 - val_dmae: 142658.0625 - val_loss: 0.1369 - val_mae: 0.2559
Epoch 64/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106799.2969 - loss: 0.0860 - mae: 0.1916 - val_dmae: 142327.2188 - val_loss: 0.1363 - val_mae: 0.2553
Epoch 65/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105874.8125 - loss: 0.0833 - mae: 0.1899 - val_dmae: 142008.0625 - val_loss: 0.1363 - val_mae: 0.2547
Epoch 66/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108118.6641 - loss: 0.0855 - mae: 0.1940 - val_dmae: 141454.9531 - val_loss: 0.1356 - val_mae: 0.2538
Epoch 67/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109463.1953 - loss: 0.0893 - mae: 0.1964 - val_dmae: 141685.8750 - val_loss: 0.1352 - val_mae: 0.2542
Epoch 68/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105825.5234 - loss: 0.0858 - mae: 0.1898 - val_dmae: 140991.7812 - val_loss: 0.1351 - val_mae: 0.2529
Epoch 69/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106448.1719 - loss: 0.0838 - mae: 0.1910 - val_dmae: 140771.2031 - val_loss: 0.1346 - val_mae: 0.2525
Epoch 70/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104940.3516 - loss: 0.0825 - mae: 0.1883 - val_dmae: 141016.3750 - val_loss: 0.1346 - val_mae: 0.2530
Epoch 71/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106284.2344 - loss: 0.0846 - mae: 0.1907 - val_dmae: 140602.1562 - val_loss: 0.1338 - val_mae: 0.2522
Epoch 72/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106217.8594 - loss: 0.0820 - mae: 0.1905 - val_dmae: 140164.5625 - val_loss: 0.1336 - val_mae: 0.2514
Epoch 73/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105000.2891 - loss: 0.0818 - mae: 0.1884 - val_dmae: 140387.6406 - val_loss: 0.1332 - val_mae: 0.2518
Epoch 74/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106409.8359 - loss: 0.0834 - mae: 0.1909 - val_dmae: 139983.8594 - val_loss: 0.1331 - val_mae: 0.2511
Epoch 75/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106516.8438 - loss: 0.0852 - mae: 0.1911 - val_dmae: 139730.4844 - val_loss: 0.1326 - val_mae: 0.2507
Epoch 76/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105437.2344 - loss: 0.0820 - mae: 0.1891 - val_dmae: 139861.8281 - val_loss: 0.1328 - val_mae: 0.2509
Epoch 77/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106327.5234 - loss: 0.0846 - mae: 0.1907 - val_dmae: 139872.5938 - val_loss: 0.1324 - val_mae: 0.2509
Epoch 78/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106146.8438 - loss: 0.0838 - mae: 0.1904 - val_dmae: 139626.1875 - val_loss: 0.1321 - val_mae: 0.2505
Epoch 79/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 107474.9922 - loss: 0.0847 - mae: 0.1928 - val_dmae: 139721.1719 - val_loss: 0.1324 - val_mae: 0.2506
Epoch 80/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107348.9219 - loss: 0.0847 - mae: 0.1926 - val_dmae: 139577.4062 - val_loss: 0.1320 - val_mae: 0.2504
Epoch 81/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106979.7734 - loss: 0.0847 - mae: 0.1919 - val_dmae: 138549.5312 - val_loss: 0.1311 - val_mae: 0.2485
Epoch 82/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105385.9531 - loss: 0.0832 - mae: 0.1891 - val_dmae: 138810.9219 - val_loss: 0.1309 - val_mae: 0.2490
Epoch 83/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105964.1875 - loss: 0.0833 - mae: 0.1901 - val_dmae: 139345.7812 - val_loss: 0.1312 - val_mae: 0.2500
Epoch 84/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104889.5781 - loss: 0.0820 - mae: 0.1882 - val_dmae: 138710.0625 - val_loss: 0.1308 - val_mae: 0.2488
Epoch 85/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 105695.7031 - loss: 0.0836 - mae: 0.1896 - val_dmae: 138363.7500 - val_loss: 0.1302 - val_mae: 0.2482
Epoch 86/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 105218.9688 - loss: 0.0839 - mae: 0.1888 - val_dmae: 139187.0781 - val_loss: 0.1305 - val_mae: 0.2497
Epoch 87/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105252.5781 - loss: 0.0844 - mae: 0.1888 - val_dmae: 138020.2031 - val_loss: 0.1299 - val_mae: 0.2476
Epoch 88/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104995.2734 - loss: 0.0826 - mae: 0.1884 - val_dmae: 137290.6875 - val_loss: 0.1292 - val_mae: 0.2463
Epoch 89/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105696.9922 - loss: 0.0826 - mae: 0.1896 - val_dmae: 138386.2031 - val_loss: 0.1299 - val_mae: 0.2483
Epoch 90/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 104927.4141 - loss: 0.0818 - mae: 0.1882 - val_dmae: 137169.7188 - val_loss: 0.1290 - val_mae: 0.2461
Epoch 91/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 104667.6094 - loss: 0.0822 - mae: 0.1878 - val_dmae: 137768.0625 - val_loss: 0.1292 - val_mae: 0.2471
Epoch 92/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103445.3047 - loss: 0.0821 - mae: 0.1856 - val_dmae: 137818.1250 - val_loss: 0.1293 - val_mae: 0.2472
Epoch 93/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103454.4453 - loss: 0.0787 - mae: 0.1856 - val_dmae: 137255.6719 - val_loss: 0.1284 - val_mae: 0.2462
Epoch 94/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104167.6328 - loss: 0.0809 - mae: 0.1869 - val_dmae: 136743.0469 - val_loss: 0.1280 - val_mae: 0.2453
Epoch 95/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103535.7812 - loss: 0.0788 - mae: 0.1857 - val_dmae: 136691.6406 - val_loss: 0.1279 - val_mae: 0.2452
Epoch 96/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104826.0703 - loss: 0.0823 - mae: 0.1880 - val_dmae: 136810.1406 - val_loss: 0.1279 - val_mae: 0.2454
Epoch 97/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106355.6016 - loss: 0.0829 - mae: 0.1908 - val_dmae: 136496.6406 - val_loss: 0.1273 - val_mae: 0.2449
Epoch 98/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105626.4219 - loss: 0.0839 - mae: 0.1895 - val_dmae: 136440.7812 - val_loss: 0.1276 - val_mae: 0.2448
Epoch 99/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104637.9688 - loss: 0.0823 - mae: 0.1877 - val_dmae: 136642.5625 - val_loss: 0.1270 - val_mae: 0.2451
Epoch 100/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104655.5000 - loss: 0.0811 - mae: 0.1877 - val_dmae: 136200.2188 - val_loss: 0.1271 - val_mae: 0.2443
Epoch 101/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104018.0781 - loss: 0.0807 - mae: 0.1866 - val_dmae: 136431.9844 - val_loss: 0.1267 - val_mae: 0.2447
Epoch 102/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104705.2578 - loss: 0.0799 - mae: 0.1878 - val_dmae: 135655.1406 - val_loss: 0.1261 - val_mae: 0.2434
Epoch 103/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104673.5547 - loss: 0.0800 - mae: 0.1878 - val_dmae: 136366.9375 - val_loss: 0.1266 - val_mae: 0.2446
Epoch 104/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102701.0391 - loss: 0.0790 - mae: 0.1842 - val_dmae: 135848.4688 - val_loss: 0.1261 - val_mae: 0.2437
Epoch 105/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105535.8281 - loss: 0.0829 - mae: 0.1893 - val_dmae: 136197.0156 - val_loss: 0.1263 - val_mae: 0.2443
Epoch 106/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102814.2266 - loss: 0.0789 - mae: 0.1844 - val_dmae: 136579.7188 - val_loss: 0.1269 - val_mae: 0.2450
Epoch 107/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104104.2266 - loss: 0.0808 - mae: 0.1868 - val_dmae: 134461.6562 - val_loss: 0.1247 - val_mae: 0.2412
Epoch 108/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102041.6641 - loss: 0.0780 - mae: 0.1831 - val_dmae: 135754.6875 - val_loss: 0.1254 - val_mae: 0.2435
Epoch 109/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103270.1797 - loss: 0.0788 - mae: 0.1853 - val_dmae: 135350.5625 - val_loss: 0.1254 - val_mae: 0.2428
Epoch 110/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104910.7344 - loss: 0.0816 - mae: 0.1882 - val_dmae: 134735.2188 - val_loss: 0.1251 - val_mae: 0.2417
Epoch 111/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103093.6016 - loss: 0.0793 - mae: 0.1849 - val_dmae: 135348.1875 - val_loss: 0.1248 - val_mae: 0.2428
Epoch 112/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 104613.5078 - loss: 0.0824 - mae: 0.1877 - val_dmae: 136088.3594 - val_loss: 0.1257 - val_mae: 0.2441
Epoch 112: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 67069.3906 - loss: 0.0270 - mae: 0.1203
Out[3]:
[0.02606853097677231, 64549.47265625, 0.11579488962888718]
In [4]:
# S = 3
model3 = WalmartModel(3, hidden_size=10)
model3.train(train, val, 'results/walmart3.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model3.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 43ms/step - dmae: 540936.6250 - loss: 1.2712 - mae: 0.9704 - val_dmae: 446060.9375 - val_loss: 0.8851 - val_mae: 0.8002
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 526711.0000 - loss: 1.2101 - mae: 0.9449 - val_dmae: 417605.8438 - val_loss: 0.7719 - val_mae: 0.7491
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 472037.7812 - loss: 0.9910 - mae: 0.8468 - val_dmae: 290335.1875 - val_loss: 0.3701 - val_mae: 0.5208
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 257934.1094 - loss: 0.3433 - mae: 0.4627 - val_dmae: 186003.2031 - val_loss: 0.2247 - val_mae: 0.3337
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 132825.7969 - loss: 0.1373 - mae: 0.2383 - val_dmae: 171562.5156 - val_loss: 0.1824 - val_mae: 0.3078
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 122590.3438 - loss: 0.1296 - mae: 0.2199 - val_dmae: 168244.2969 - val_loss: 0.1790 - val_mae: 0.3018
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120699.1172 - loss: 0.1292 - mae: 0.2165 - val_dmae: 166717.6094 - val_loss: 0.1753 - val_mae: 0.2991
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119165.5938 - loss: 0.1258 - mae: 0.2138 - val_dmae: 165173.4531 - val_loss: 0.1717 - val_mae: 0.2963
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119174.4141 - loss: 0.1264 - mae: 0.2138 - val_dmae: 163359.1562 - val_loss: 0.1675 - val_mae: 0.2930
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118432.2266 - loss: 0.1264 - mae: 0.2125 - val_dmae: 162832.9062 - val_loss: 0.1660 - val_mae: 0.2921
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 117244.1484 - loss: 0.1241 - mae: 0.2103 - val_dmae: 162737.7969 - val_loss: 0.1639 - val_mae: 0.2919
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120068.3359 - loss: 0.1260 - mae: 0.2154 - val_dmae: 162561.1719 - val_loss: 0.1632 - val_mae: 0.2916
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117621.6562 - loss: 0.1246 - mae: 0.2110 - val_dmae: 161742.9531 - val_loss: 0.1609 - val_mae: 0.2901
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118960.2812 - loss: 0.1226 - mae: 0.2134 - val_dmae: 161482.3906 - val_loss: 0.1603 - val_mae: 0.2897
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118165.0078 - loss: 0.1215 - mae: 0.2120 - val_dmae: 160821.7812 - val_loss: 0.1583 - val_mae: 0.2885
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117926.8594 - loss: 0.1220 - mae: 0.2115 - val_dmae: 160160.2031 - val_loss: 0.1579 - val_mae: 0.2873
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 116689.5000 - loss: 0.1217 - mae: 0.2093 - val_dmae: 160171.0156 - val_loss: 0.1567 - val_mae: 0.2873
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117721.7969 - loss: 0.1217 - mae: 0.2112 - val_dmae: 159014.6875 - val_loss: 0.1535 - val_mae: 0.2853
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117070.3750 - loss: 0.1192 - mae: 0.2100 - val_dmae: 159033.2656 - val_loss: 0.1544 - val_mae: 0.2853
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116089.2422 - loss: 0.1179 - mae: 0.2083 - val_dmae: 158147.5156 - val_loss: 0.1514 - val_mae: 0.2837
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116260.5703 - loss: 0.1175 - mae: 0.2086 - val_dmae: 156836.6406 - val_loss: 0.1490 - val_mae: 0.2813
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116733.7031 - loss: 0.1175 - mae: 0.2094 - val_dmae: 156020.1406 - val_loss: 0.1466 - val_mae: 0.2799
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115904.0000 - loss: 0.1175 - mae: 0.2079 - val_dmae: 155950.4688 - val_loss: 0.1470 - val_mae: 0.2798
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116191.7578 - loss: 0.1182 - mae: 0.2084 - val_dmae: 155160.9844 - val_loss: 0.1458 - val_mae: 0.2783
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115599.6250 - loss: 0.1145 - mae: 0.2074 - val_dmae: 154370.4219 - val_loss: 0.1448 - val_mae: 0.2769
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115037.5391 - loss: 0.1174 - mae: 0.2064 - val_dmae: 153504.0938 - val_loss: 0.1440 - val_mae: 0.2754
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112945.7500 - loss: 0.1118 - mae: 0.2026 - val_dmae: 152736.9688 - val_loss: 0.1438 - val_mae: 0.2740
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113147.0781 - loss: 0.1107 - mae: 0.2030 - val_dmae: 151582.6719 - val_loss: 0.1406 - val_mae: 0.2719
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113631.4766 - loss: 0.1121 - mae: 0.2038 - val_dmae: 150585.4688 - val_loss: 0.1406 - val_mae: 0.2701
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112733.0781 - loss: 0.1119 - mae: 0.2022 - val_dmae: 149879.8594 - val_loss: 0.1394 - val_mae: 0.2689
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112611.1797 - loss: 0.1105 - mae: 0.2020 - val_dmae: 148606.7500 - val_loss: 0.1372 - val_mae: 0.2666
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111950.7891 - loss: 0.1089 - mae: 0.2008 - val_dmae: 147362.0312 - val_loss: 0.1373 - val_mae: 0.2644
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111432.6562 - loss: 0.1095 - mae: 0.1999 - val_dmae: 146865.7812 - val_loss: 0.1354 - val_mae: 0.2635
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112428.4688 - loss: 0.1088 - mae: 0.2017 - val_dmae: 145913.6562 - val_loss: 0.1349 - val_mae: 0.2618
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111081.5469 - loss: 0.1066 - mae: 0.1993 - val_dmae: 145309.6562 - val_loss: 0.1339 - val_mae: 0.2607
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110126.4844 - loss: 0.1060 - mae: 0.1976 - val_dmae: 143942.8750 - val_loss: 0.1329 - val_mae: 0.2582
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111330.8828 - loss: 0.1068 - mae: 0.1997 - val_dmae: 143556.3750 - val_loss: 0.1332 - val_mae: 0.2575
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110325.1953 - loss: 0.1053 - mae: 0.1979 - val_dmae: 141895.3438 - val_loss: 0.1299 - val_mae: 0.2545
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111159.8906 - loss: 0.1054 - mae: 0.1994 - val_dmae: 140494.4688 - val_loss: 0.1277 - val_mae: 0.2520
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109931.8906 - loss: 0.1042 - mae: 0.1972 - val_dmae: 139570.3281 - val_loss: 0.1268 - val_mae: 0.2504
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111023.9531 - loss: 0.1062 - mae: 0.1992 - val_dmae: 139635.0625 - val_loss: 0.1283 - val_mae: 0.2505
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108282.9219 - loss: 0.1023 - mae: 0.1942 - val_dmae: 138786.0156 - val_loss: 0.1265 - val_mae: 0.2490
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108149.6328 - loss: 0.1018 - mae: 0.1940 - val_dmae: 137035.6250 - val_loss: 0.1244 - val_mae: 0.2458
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108169.9219 - loss: 0.1028 - mae: 0.1940 - val_dmae: 137295.5312 - val_loss: 0.1251 - val_mae: 0.2463
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107773.6016 - loss: 0.1010 - mae: 0.1933 - val_dmae: 135757.5000 - val_loss: 0.1234 - val_mae: 0.2435
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107828.5312 - loss: 0.1023 - mae: 0.1934 - val_dmae: 135303.0625 - val_loss: 0.1226 - val_mae: 0.2427
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 107727.1484 - loss: 0.1017 - mae: 0.1933 - val_dmae: 135700.1406 - val_loss: 0.1245 - val_mae: 0.2434
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108498.3203 - loss: 0.1000 - mae: 0.1946 - val_dmae: 134909.6875 - val_loss: 0.1239 - val_mae: 0.2420
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107736.9609 - loss: 0.0997 - mae: 0.1933 - val_dmae: 133203.0312 - val_loss: 0.1201 - val_mae: 0.2390
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108150.8750 - loss: 0.0993 - mae: 0.1940 - val_dmae: 133000.3594 - val_loss: 0.1191 - val_mae: 0.2386
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107397.0469 - loss: 0.0992 - mae: 0.1927 - val_dmae: 132878.0625 - val_loss: 0.1196 - val_mae: 0.2384
Epoch 52/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109820.4297 - loss: 0.1019 - mae: 0.1970 - val_dmae: 131335.3594 - val_loss: 0.1181 - val_mae: 0.2356
Epoch 53/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107559.5000 - loss: 0.0980 - mae: 0.1930 - val_dmae: 132667.8906 - val_loss: 0.1199 - val_mae: 0.2380
Epoch 54/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106483.0859 - loss: 0.0995 - mae: 0.1910 - val_dmae: 131244.6406 - val_loss: 0.1177 - val_mae: 0.2354
Epoch 55/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107027.5000 - loss: 0.0969 - mae: 0.1920 - val_dmae: 130604.6719 - val_loss: 0.1171 - val_mae: 0.2343
Epoch 56/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106827.5938 - loss: 0.0976 - mae: 0.1916 - val_dmae: 131395.9531 - val_loss: 0.1184 - val_mae: 0.2357
Epoch 57/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107072.7266 - loss: 0.0979 - mae: 0.1921 - val_dmae: 129069.6094 - val_loss: 0.1142 - val_mae: 0.2315
Epoch 58/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 107479.7734 - loss: 0.0988 - mae: 0.1928 - val_dmae: 129324.2578 - val_loss: 0.1153 - val_mae: 0.2320
Epoch 59/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104877.4062 - loss: 0.0967 - mae: 0.1881 - val_dmae: 130815.4141 - val_loss: 0.1172 - val_mae: 0.2347
Epoch 60/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106237.4375 - loss: 0.0994 - mae: 0.1906 - val_dmae: 128380.0391 - val_loss: 0.1135 - val_mae: 0.2303
Epoch 61/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105664.3203 - loss: 0.0957 - mae: 0.1896 - val_dmae: 128674.1875 - val_loss: 0.1139 - val_mae: 0.2308
Epoch 62/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106641.8828 - loss: 0.0975 - mae: 0.1913 - val_dmae: 127305.0312 - val_loss: 0.1122 - val_mae: 0.2284
Epoch 63/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107236.2109 - loss: 0.0992 - mae: 0.1924 - val_dmae: 127473.6016 - val_loss: 0.1116 - val_mae: 0.2287
Epoch 64/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106514.9141 - loss: 0.0961 - mae: 0.1911 - val_dmae: 127776.3828 - val_loss: 0.1119 - val_mae: 0.2292
Epoch 65/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106053.3594 - loss: 0.0958 - mae: 0.1902 - val_dmae: 126682.3516 - val_loss: 0.1109 - val_mae: 0.2273
Epoch 66/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 105466.3984 - loss: 0.0949 - mae: 0.1892 - val_dmae: 126861.6484 - val_loss: 0.1102 - val_mae: 0.2276
Epoch 67/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 107242.6875 - loss: 0.0959 - mae: 0.1924 - val_dmae: 127207.3672 - val_loss: 0.1109 - val_mae: 0.2282
Epoch 68/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107251.6328 - loss: 0.0957 - mae: 0.1924 - val_dmae: 125882.2734 - val_loss: 0.1085 - val_mae: 0.2258
Epoch 69/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107183.1797 - loss: 0.0940 - mae: 0.1923 - val_dmae: 126138.2188 - val_loss: 0.1083 - val_mae: 0.2263
Epoch 70/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106752.0781 - loss: 0.0926 - mae: 0.1915 - val_dmae: 126654.4922 - val_loss: 0.1091 - val_mae: 0.2272
Epoch 71/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105356.8359 - loss: 0.0945 - mae: 0.1890 - val_dmae: 126322.1719 - val_loss: 0.1086 - val_mae: 0.2266
Epoch 72/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105719.0391 - loss: 0.0933 - mae: 0.1896 - val_dmae: 125724.5234 - val_loss: 0.1082 - val_mae: 0.2255
Epoch 73/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104421.3203 - loss: 0.0906 - mae: 0.1873 - val_dmae: 124994.4219 - val_loss: 0.1073 - val_mae: 0.2242
Epoch 74/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103898.9297 - loss: 0.0898 - mae: 0.1864 - val_dmae: 126199.6406 - val_loss: 0.1087 - val_mae: 0.2264
Epoch 75/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105222.2500 - loss: 0.0911 - mae: 0.1888 - val_dmae: 125876.5703 - val_loss: 0.1080 - val_mae: 0.2258
Epoch 76/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104338.0625 - loss: 0.0906 - mae: 0.1872 - val_dmae: 126101.1406 - val_loss: 0.1084 - val_mae: 0.2262
Epoch 77/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 102422.9766 - loss: 0.0876 - mae: 0.1837 - val_dmae: 124594.4141 - val_loss: 0.1065 - val_mae: 0.2235
Epoch 78/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103584.4844 - loss: 0.0885 - mae: 0.1858 - val_dmae: 124104.4922 - val_loss: 0.1053 - val_mae: 0.2226
Epoch 79/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104298.5938 - loss: 0.0881 - mae: 0.1871 - val_dmae: 124902.7422 - val_loss: 0.1058 - val_mae: 0.2241
Epoch 80/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103117.6797 - loss: 0.0886 - mae: 0.1850 - val_dmae: 124353.8672 - val_loss: 0.1056 - val_mae: 0.2231
Epoch 81/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104999.2266 - loss: 0.0891 - mae: 0.1884 - val_dmae: 124839.6094 - val_loss: 0.1061 - val_mae: 0.2239
Epoch 82/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 102493.0156 - loss: 0.0851 - mae: 0.1839 - val_dmae: 125998.2734 - val_loss: 0.1073 - val_mae: 0.2260
Epoch 83/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103112.1172 - loss: 0.0856 - mae: 0.1850 - val_dmae: 124823.8359 - val_loss: 0.1056 - val_mae: 0.2239
Epoch 83: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 72043.2344 - loss: 0.0310 - mae: 0.1292
Out[4]:
[0.02758491039276123, 66173.328125, 0.11870791763067245]
In [5]:
# S = 4
model4 = WalmartModel(4, hidden_size=10)
model4.train(train, val, 'results/walmart4.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model4.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 4s 44ms/step - dmae: 541183.5000 - loss: 1.2739 - mae: 0.9708 - val_dmae: 445947.1250 - val_loss: 0.8849 - val_mae: 0.8000
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 523680.4062 - loss: 1.1994 - mae: 0.9394 - val_dmae: 389234.1562 - val_loss: 0.6675 - val_mae: 0.6982
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 410993.5625 - loss: 0.7863 - mae: 0.7373 - val_dmae: 204479.6875 - val_loss: 0.2104 - val_mae: 0.3668
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 140061.5156 - loss: 0.1481 - mae: 0.2513 - val_dmae: 179813.2812 - val_loss: 0.1822 - val_mae: 0.3226
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 123100.4844 - loss: 0.1321 - mae: 0.2208 - val_dmae: 177829.4844 - val_loss: 0.1793 - val_mae: 0.3190
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 121718.7812 - loss: 0.1309 - mae: 0.2184 - val_dmae: 176610.0781 - val_loss: 0.1757 - val_mae: 0.3168
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 122855.4531 - loss: 0.1315 - mae: 0.2204 - val_dmae: 175641.4844 - val_loss: 0.1735 - val_mae: 0.3151
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121785.8672 - loss: 0.1301 - mae: 0.2185 - val_dmae: 175341.3594 - val_loss: 0.1713 - val_mae: 0.3145
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120890.7188 - loss: 0.1289 - mae: 0.2169 - val_dmae: 174643.0781 - val_loss: 0.1695 - val_mae: 0.3133
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121292.8984 - loss: 0.1299 - mae: 0.2176 - val_dmae: 174020.7812 - val_loss: 0.1675 - val_mae: 0.3122
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 120828.5781 - loss: 0.1299 - mae: 0.2168 - val_dmae: 173717.3281 - val_loss: 0.1668 - val_mae: 0.3116
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120860.0391 - loss: 0.1286 - mae: 0.2168 - val_dmae: 173037.5781 - val_loss: 0.1651 - val_mae: 0.3104
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118227.3750 - loss: 0.1263 - mae: 0.2121 - val_dmae: 172579.0781 - val_loss: 0.1642 - val_mae: 0.3096
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120309.1562 - loss: 0.1275 - mae: 0.2158 - val_dmae: 172116.3594 - val_loss: 0.1632 - val_mae: 0.3088
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119245.2422 - loss: 0.1274 - mae: 0.2139 - val_dmae: 171584.2969 - val_loss: 0.1626 - val_mae: 0.3078
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119667.2344 - loss: 0.1261 - mae: 0.2147 - val_dmae: 170852.2812 - val_loss: 0.1609 - val_mae: 0.3065
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118587.4531 - loss: 0.1247 - mae: 0.2127 - val_dmae: 170126.2031 - val_loss: 0.1599 - val_mae: 0.3052
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119761.6016 - loss: 0.1256 - mae: 0.2148 - val_dmae: 169628.2500 - val_loss: 0.1588 - val_mae: 0.3043
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118999.5859 - loss: 0.1254 - mae: 0.2135 - val_dmae: 169241.3125 - val_loss: 0.1578 - val_mae: 0.3036
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 120278.4688 - loss: 0.1257 - mae: 0.2158 - val_dmae: 168276.1250 - val_loss: 0.1565 - val_mae: 0.3019
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119082.6094 - loss: 0.1239 - mae: 0.2136 - val_dmae: 167657.5312 - val_loss: 0.1552 - val_mae: 0.3008
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118829.9531 - loss: 0.1235 - mae: 0.2132 - val_dmae: 166965.6875 - val_loss: 0.1547 - val_mae: 0.2995
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117828.6094 - loss: 0.1223 - mae: 0.2114 - val_dmae: 166416.9531 - val_loss: 0.1534 - val_mae: 0.2985
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118497.5547 - loss: 0.1225 - mae: 0.2126 - val_dmae: 165379.1250 - val_loss: 0.1522 - val_mae: 0.2967
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117032.6328 - loss: 0.1208 - mae: 0.2099 - val_dmae: 164990.1406 - val_loss: 0.1516 - val_mae: 0.2960
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118040.7656 - loss: 0.1230 - mae: 0.2118 - val_dmae: 163983.4375 - val_loss: 0.1509 - val_mae: 0.2942
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115586.5312 - loss: 0.1188 - mae: 0.2073 - val_dmae: 163506.5000 - val_loss: 0.1496 - val_mae: 0.2933
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116925.1016 - loss: 0.1226 - mae: 0.2098 - val_dmae: 162691.6562 - val_loss: 0.1489 - val_mae: 0.2919
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117944.3047 - loss: 0.1194 - mae: 0.2116 - val_dmae: 161717.6562 - val_loss: 0.1479 - val_mae: 0.2901
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116896.6641 - loss: 0.1212 - mae: 0.2097 - val_dmae: 160816.7969 - val_loss: 0.1455 - val_mae: 0.2885
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116105.2422 - loss: 0.1179 - mae: 0.2083 - val_dmae: 160315.1094 - val_loss: 0.1462 - val_mae: 0.2876
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116258.7734 - loss: 0.1181 - mae: 0.2086 - val_dmae: 159055.1719 - val_loss: 0.1445 - val_mae: 0.2853
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115883.6797 - loss: 0.1173 - mae: 0.2079 - val_dmae: 158377.6094 - val_loss: 0.1431 - val_mae: 0.2841
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 115552.8828 - loss: 0.1167 - mae: 0.2073 - val_dmae: 157624.4062 - val_loss: 0.1428 - val_mae: 0.2828
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114920.1016 - loss: 0.1162 - mae: 0.2062 - val_dmae: 156869.3438 - val_loss: 0.1410 - val_mae: 0.2814
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 115326.1328 - loss: 0.1162 - mae: 0.2069 - val_dmae: 155840.2656 - val_loss: 0.1408 - val_mae: 0.2796
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114686.2031 - loss: 0.1160 - mae: 0.2057 - val_dmae: 154683.8438 - val_loss: 0.1384 - val_mae: 0.2775
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115514.4844 - loss: 0.1158 - mae: 0.2072 - val_dmae: 153769.5781 - val_loss: 0.1385 - val_mae: 0.2758
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114226.4062 - loss: 0.1153 - mae: 0.2049 - val_dmae: 153012.1719 - val_loss: 0.1362 - val_mae: 0.2745
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113979.2031 - loss: 0.1135 - mae: 0.2045 - val_dmae: 152202.2812 - val_loss: 0.1354 - val_mae: 0.2730
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113728.9062 - loss: 0.1115 - mae: 0.2040 - val_dmae: 151430.1094 - val_loss: 0.1353 - val_mae: 0.2716
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 113138.7812 - loss: 0.1108 - mae: 0.2030 - val_dmae: 150110.1406 - val_loss: 0.1350 - val_mae: 0.2693
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112732.0078 - loss: 0.1121 - mae: 0.2022 - val_dmae: 149602.0312 - val_loss: 0.1334 - val_mae: 0.2684
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112953.9062 - loss: 0.1103 - mae: 0.2026 - val_dmae: 148456.2188 - val_loss: 0.1319 - val_mae: 0.2663
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112995.1875 - loss: 0.1098 - mae: 0.2027 - val_dmae: 147468.7188 - val_loss: 0.1319 - val_mae: 0.2645
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112818.1875 - loss: 0.1119 - mae: 0.2024 - val_dmae: 146326.2031 - val_loss: 0.1303 - val_mae: 0.2625
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111875.6641 - loss: 0.1095 - mae: 0.2007 - val_dmae: 145698.6250 - val_loss: 0.1292 - val_mae: 0.2614
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110506.3906 - loss: 0.1053 - mae: 0.1982 - val_dmae: 144169.7656 - val_loss: 0.1285 - val_mae: 0.2586
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112231.7969 - loss: 0.1075 - mae: 0.2013 - val_dmae: 143225.8281 - val_loss: 0.1270 - val_mae: 0.2569
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110201.1484 - loss: 0.1069 - mae: 0.1977 - val_dmae: 142438.1719 - val_loss: 0.1259 - val_mae: 0.2555
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110309.5938 - loss: 0.1065 - mae: 0.1979 - val_dmae: 141394.1406 - val_loss: 0.1253 - val_mae: 0.2536
Epoch 52/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108985.8203 - loss: 0.1044 - mae: 0.1955 - val_dmae: 140468.4688 - val_loss: 0.1243 - val_mae: 0.2520
Epoch 53/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111539.7656 - loss: 0.1061 - mae: 0.2001 - val_dmae: 139558.6094 - val_loss: 0.1235 - val_mae: 0.2504
Epoch 54/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109181.4688 - loss: 0.1025 - mae: 0.1959 - val_dmae: 138296.8281 - val_loss: 0.1217 - val_mae: 0.2481
Epoch 55/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109968.4609 - loss: 0.1032 - mae: 0.1973 - val_dmae: 137570.2344 - val_loss: 0.1217 - val_mae: 0.2468
Epoch 56/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107895.3750 - loss: 0.0994 - mae: 0.1936 - val_dmae: 136646.6875 - val_loss: 0.1208 - val_mae: 0.2451
Epoch 57/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107892.7422 - loss: 0.1010 - mae: 0.1935 - val_dmae: 136022.4531 - val_loss: 0.1198 - val_mae: 0.2440
Epoch 58/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108493.5859 - loss: 0.1007 - mae: 0.1946 - val_dmae: 134835.5938 - val_loss: 0.1189 - val_mae: 0.2419
Epoch 59/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109186.6172 - loss: 0.1018 - mae: 0.1959 - val_dmae: 134169.0625 - val_loss: 0.1181 - val_mae: 0.2407
Epoch 60/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109062.6641 - loss: 0.1026 - mae: 0.1956 - val_dmae: 133646.3125 - val_loss: 0.1180 - val_mae: 0.2397
Epoch 61/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107113.2734 - loss: 0.0979 - mae: 0.1921 - val_dmae: 132571.2188 - val_loss: 0.1162 - val_mae: 0.2378
Epoch 62/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105384.2266 - loss: 0.0981 - mae: 0.1890 - val_dmae: 131826.2031 - val_loss: 0.1162 - val_mae: 0.2365
Epoch 63/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106014.0703 - loss: 0.0974 - mae: 0.1902 - val_dmae: 131224.7031 - val_loss: 0.1149 - val_mae: 0.2354
Epoch 64/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106483.6250 - loss: 0.0976 - mae: 0.1910 - val_dmae: 130980.0000 - val_loss: 0.1144 - val_mae: 0.2350
Epoch 65/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 105500.1406 - loss: 0.0961 - mae: 0.1893 - val_dmae: 129380.7969 - val_loss: 0.1125 - val_mae: 0.2321
Epoch 66/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107148.6094 - loss: 0.0964 - mae: 0.1922 - val_dmae: 129163.6406 - val_loss: 0.1125 - val_mae: 0.2317
Epoch 67/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106408.5000 - loss: 0.0970 - mae: 0.1909 - val_dmae: 129165.8047 - val_loss: 0.1120 - val_mae: 0.2317
Epoch 68/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104381.0312 - loss: 0.0939 - mae: 0.1872 - val_dmae: 127756.3828 - val_loss: 0.1108 - val_mae: 0.2292
Epoch 69/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105224.0547 - loss: 0.0935 - mae: 0.1888 - val_dmae: 126828.8438 - val_loss: 0.1095 - val_mae: 0.2275
Epoch 70/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104486.1484 - loss: 0.0930 - mae: 0.1874 - val_dmae: 126969.3359 - val_loss: 0.1094 - val_mae: 0.2278
Epoch 71/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 103941.8906 - loss: 0.0921 - mae: 0.1865 - val_dmae: 126879.2422 - val_loss: 0.1091 - val_mae: 0.2276
Epoch 72/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105350.7734 - loss: 0.0933 - mae: 0.1890 - val_dmae: 126084.0000 - val_loss: 0.1085 - val_mae: 0.2262
Epoch 73/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104282.8750 - loss: 0.0908 - mae: 0.1871 - val_dmae: 126411.7500 - val_loss: 0.1079 - val_mae: 0.2268
Epoch 74/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 104767.8828 - loss: 0.0938 - mae: 0.1879 - val_dmae: 124989.8672 - val_loss: 0.1067 - val_mae: 0.2242
Epoch 75/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104791.2188 - loss: 0.0916 - mae: 0.1880 - val_dmae: 124792.5938 - val_loss: 0.1059 - val_mae: 0.2239
Epoch 76/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104388.0078 - loss: 0.0911 - mae: 0.1873 - val_dmae: 125464.5000 - val_loss: 0.1064 - val_mae: 0.2251
Epoch 77/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104210.6797 - loss: 0.0914 - mae: 0.1869 - val_dmae: 126268.2969 - val_loss: 0.1070 - val_mae: 0.2265
Epoch 78/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103442.5859 - loss: 0.0917 - mae: 0.1856 - val_dmae: 125105.3750 - val_loss: 0.1058 - val_mae: 0.2244
Epoch 79/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103514.8906 - loss: 0.0897 - mae: 0.1857 - val_dmae: 123832.4219 - val_loss: 0.1045 - val_mae: 0.2221
Epoch 80/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103531.8594 - loss: 0.0920 - mae: 0.1857 - val_dmae: 123028.3438 - val_loss: 0.1036 - val_mae: 0.2207
Epoch 81/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103358.8594 - loss: 0.0879 - mae: 0.1854 - val_dmae: 123414.1641 - val_loss: 0.1035 - val_mae: 0.2214
Epoch 82/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103729.3047 - loss: 0.0902 - mae: 0.1861 - val_dmae: 122685.9375 - val_loss: 0.1031 - val_mae: 0.2201
Epoch 83/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103041.3906 - loss: 0.0902 - mae: 0.1848 - val_dmae: 123704.9766 - val_loss: 0.1038 - val_mae: 0.2219
Epoch 84/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102791.0625 - loss: 0.0872 - mae: 0.1844 - val_dmae: 125045.1719 - val_loss: 0.1050 - val_mae: 0.2243
Epoch 85/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103251.6406 - loss: 0.0866 - mae: 0.1852 - val_dmae: 124456.6250 - val_loss: 0.1044 - val_mae: 0.2233
Epoch 86/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 101863.6719 - loss: 0.0857 - mae: 0.1827 - val_dmae: 123422.0469 - val_loss: 0.1035 - val_mae: 0.2214
Epoch 87/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104849.0078 - loss: 0.0879 - mae: 0.1881 - val_dmae: 124842.6250 - val_loss: 0.1044 - val_mae: 0.2240
Epoch 87: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 78321.1094 - loss: 0.0377 - mae: 0.1405
Out[5]:
[0.03053026646375656, 68176.4375, 0.12230127304792404]

The baseline network achieves a test DMAE of roughly 64k, 66k and 68k with a sequence length $S$ of 2, 3 and 4, respectively. In theory, feeding the model more input information (e.g. the model with $S=4$ sees more past timesteps than the first model with $S=2$) should yield results at least as good as those of simpler input representations. In practice, optimizing networks on inputs with a lot of noise (data that carries no information about the target) and without regularization techniques is extremely hard and would require a considerable amount of data to ensure the generalization of the model. In our case, from the results obtained with the baseline model, it seems that providing more than $S=2$ past timesteps does not help the network predict the target outcome, so sequences of length 2 should suffice. For the next models we fixed this hyperparameter to $S=2$.
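As an illustration of how the sequence length $S$ shapes the training samples, the following toy sketch builds fixed-length input windows with a target two timesteps ahead, mirroring the $t+2$ prediction setup described above (a minimal example on a univariate series; the actual windowing is handled internally by WalmartDataset in data.py, and the helper name make_windows is our own):

```python
import numpy as np

def make_windows(series, S):
    """Build (X, y) pairs: X holds S consecutive timesteps and
    y is the value two steps after the last input timestep (t+2)."""
    X, y = [], []
    for t in range(len(series) - S - 1):
        X.append(series[t:t + S])    # input window of length S
        y.append(series[t + S + 1])  # target two steps ahead of the window
    return np.stack(X), np.array(y)

sales = np.arange(10, dtype=float)   # toy series: 0, 1, ..., 9
X, y = make_windows(sales, S=2)
print(X.shape, y.shape)              # (7, 2) (7,)
```

Note that a larger $S$ widens each window but also reduces the number of usable samples, which partly explains why extra context does not automatically improve results.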

In [34]:
plot_series(model2, [train, val, test], title='Prediction with S=2').show()
plot_series(model3, [train, val, test], title='Prediction with S=3').show()
plot_series(model4, [train, val, test], title='Prediction with S=4').show()
2024-04-11 13:10:35.688926: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

To better show the performance disparity between the different models, we plotted the absolute error of the three models in a single figure. Note that the model with $S=2$ yields a higher error (especially on outlier observations), while the model with $S=4$ stays closest to the zero line, indicating a lower absolute error.
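The numerical core of such an error comparison can be sketched as follows (the values below are toy numbers, not our models' predictions; the real figure is produced by plot_errors in plots.py):

```python
import numpy as np

y_true = np.array([10., 12., 15., 11., 13.])        # toy ground truth
preds = {                                            # toy model predictions
    'S=2': np.array([11., 11., 17., 10., 14.]),
    'S=3': np.array([10.5, 12.5, 15.8, 11.2, 12.6]),
    'S=4': np.array([10.2, 12.1, 15.3, 11.1, 13.2]),
}

# absolute error per timestep for each model; plotting these curves
# against a horizontal zero line is what the comparison figure shows
errors = {name: np.abs(y_true - y_hat) for name, y_hat in preds.items()}
for name, err in errors.items():
    print(name, err.mean())
```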

In [10]:
plot_errors([model2, model3, model4], [train, val, test])
2024-04-11 12:14:35.759970: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

Increasing the complexity of the model¶

Once we obtained our baseline results, we increased the complexity of the model by modifying the hyperparameters of the network (see model.py).

  • base_layer: The recurrent base cell of the encoder $\mathcal{E}$. Two options are available: the LSTM and the GRU. Although the LSTM layer is more popular than the GRU (specifically in NLP tasks), the GRU has its advantages over the LSTM (e.g. it has fewer parameters) and is still used in other DL applications. By default, each base layer is an LSTM.
  • num_encoder_layers ($\ell$): Number of layers in the encoder $\mathcal{E}$.
  • num_decoder_layers ($\varphi$): Number of layers in the decoder.
  • hidden_size ($d_h$): Hidden dimension of the encoder $\mathcal{E}$.
  • regularizer: Kernel and bias regularizer in the hidden layers. By default there is no regularization.
  • initializer: Weight initialization. All biases are initialized to zero. By default, kernels are initialized from a random normal distribution.
  • bidirectional: Whether to process the input sequence both left-to-right and right-to-left, or only left-to-right. By default, the processing is bidirectional, so the left-to-right and right-to-left information is concatenated to produce a single contextualization of each timestep observation.
  • dropout: Dropout rate applied in the latent space of the neural architecture.

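To make these hyperparameters concrete, the following sketch assembles a comparable encoder-decoder with the Keras functional API. It is an illustrative stand-in for what WalmartModel builds in model.py, not its exact implementation; details such as the ReLU activation in the decoder and the default n_features=7 are our assumptions.

```python
from keras.models import Model
from keras.layers import Bidirectional, Dense, Dropout, Input, LSTM

def build_sketch(seq_len=2, n_features=7, hidden_size=50,
                 num_encoder_layers=3, num_decoder_layers=2,
                 base_layer=LSTM, dropout=0.0):
    inputs = Input(shape=(seq_len, n_features))
    x = inputs
    for i in range(num_encoder_layers):
        # all encoder layers but the last return the full sequence;
        # the last returns only the final state (the summary vector)
        last = i == num_encoder_layers - 1
        x = Bidirectional(base_layer(hidden_size, return_sequences=not last))(x)
        if dropout:
            x = Dropout(dropout)(x)
    for _ in range(num_decoder_layers - 1):
        x = Dense(hidden_size, activation='relu')(x)
    outputs = Dense(1, activation='linear')(x)  # linear output for regression
    return Model(inputs, outputs)

sketch = build_sketch()
```

Swapping `base_layer=GRU` reproduces the second variant tested below, and the Bidirectional wrapper concatenates both directions, doubling the hidden dimension of each recurrent output.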
The next code cell increases the dimensionality of the model to $d_h=50$ and uses 3 layers in the encoder (keeping 2 layers in the decoder). We tested two different architectures: the first one uses the LSTM cell again, while the second replaces the LSTM with the GRU cell.

In [15]:
model_lstm = WalmartModel(2, hidden_size=50, num_encoder_layers=3)
model_lstm.train(train, val, 'results/walmart_lstm.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model_lstm.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 5s 47ms/step - dmae: 540113.5000 - loss: 1.2611 - mae: 0.9689 - val_dmae: 403419.7812 - val_loss: 0.8032 - val_mae: 0.7237
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 319207.9688 - loss: 0.5448 - mae: 0.5726 - val_dmae: 180465.2656 - val_loss: 0.2324 - val_mae: 0.3237
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 122478.0469 - loss: 0.1243 - mae: 0.2197 - val_dmae: 167313.8438 - val_loss: 0.2176 - val_mae: 0.3001
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113633.9141 - loss: 0.1122 - mae: 0.2038 - val_dmae: 167765.9531 - val_loss: 0.2118 - val_mae: 0.3010
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114383.8281 - loss: 0.1143 - mae: 0.2052 - val_dmae: 167666.5625 - val_loss: 0.2088 - val_mae: 0.3008
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116631.1875 - loss: 0.1146 - mae: 0.2092 - val_dmae: 167718.7188 - val_loss: 0.2064 - val_mae: 0.3009
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 117350.9375 - loss: 0.1155 - mae: 0.2105 - val_dmae: 167279.7188 - val_loss: 0.2037 - val_mae: 0.3001
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117495.7812 - loss: 0.1139 - mae: 0.2108 - val_dmae: 166875.0312 - val_loss: 0.2002 - val_mae: 0.2994
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 118276.0234 - loss: 0.1144 - mae: 0.2122 - val_dmae: 166421.4531 - val_loss: 0.1969 - val_mae: 0.2985
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 115058.5391 - loss: 0.1087 - mae: 0.2064 - val_dmae: 165433.7344 - val_loss: 0.1922 - val_mae: 0.2968
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116101.3281 - loss: 0.1096 - mae: 0.2083 - val_dmae: 164310.3594 - val_loss: 0.1869 - val_mae: 0.2948
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113685.2656 - loss: 0.1060 - mae: 0.2039 - val_dmae: 163569.2969 - val_loss: 0.1814 - val_mae: 0.2934
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113298.7344 - loss: 0.1045 - mae: 0.2032 - val_dmae: 161447.1406 - val_loss: 0.1744 - val_mae: 0.2896
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111946.1172 - loss: 0.1011 - mae: 0.2008 - val_dmae: 160629.4844 - val_loss: 0.1695 - val_mae: 0.2882
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111412.2891 - loss: 0.0998 - mae: 0.1999 - val_dmae: 158088.5625 - val_loss: 0.1633 - val_mae: 0.2836
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108755.0547 - loss: 0.0949 - mae: 0.1951 - val_dmae: 156150.9219 - val_loss: 0.1589 - val_mae: 0.2801
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108034.8281 - loss: 0.0926 - mae: 0.1938 - val_dmae: 153489.8281 - val_loss: 0.1534 - val_mae: 0.2753
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 107057.3828 - loss: 0.0911 - mae: 0.1920 - val_dmae: 151001.2812 - val_loss: 0.1488 - val_mae: 0.2709
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104880.9609 - loss: 0.0879 - mae: 0.1881 - val_dmae: 149103.7500 - val_loss: 0.1448 - val_mae: 0.2675
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 105753.2891 - loss: 0.0872 - mae: 0.1897 - val_dmae: 147965.2812 - val_loss: 0.1424 - val_mae: 0.2654
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 104211.5781 - loss: 0.0860 - mae: 0.1869 - val_dmae: 145217.1250 - val_loss: 0.1390 - val_mae: 0.2605
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103770.3359 - loss: 0.0847 - mae: 0.1862 - val_dmae: 144004.5156 - val_loss: 0.1368 - val_mae: 0.2583
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102076.5391 - loss: 0.0813 - mae: 0.1831 - val_dmae: 142253.1250 - val_loss: 0.1340 - val_mae: 0.2552
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 102121.5234 - loss: 0.0813 - mae: 0.1832 - val_dmae: 140581.1406 - val_loss: 0.1321 - val_mae: 0.2522
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 101737.6406 - loss: 0.0808 - mae: 0.1825 - val_dmae: 140625.4844 - val_loss: 0.1312 - val_mae: 0.2523
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 99911.0547 - loss: 0.0784 - mae: 0.1792 - val_dmae: 139461.6406 - val_loss: 0.1297 - val_mae: 0.2502
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100347.7266 - loss: 0.0792 - mae: 0.1800 - val_dmae: 138025.4531 - val_loss: 0.1275 - val_mae: 0.2476
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100245.8281 - loss: 0.0769 - mae: 0.1798 - val_dmae: 138987.7500 - val_loss: 0.1285 - val_mae: 0.2493
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98870.0703 - loss: 0.0763 - mae: 0.1774 - val_dmae: 137902.8594 - val_loss: 0.1259 - val_mae: 0.2474
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 99416.6797 - loss: 0.0754 - mae: 0.1783 - val_dmae: 137591.8906 - val_loss: 0.1261 - val_mae: 0.2468
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 99176.6328 - loss: 0.0754 - mae: 0.1779 - val_dmae: 136785.6094 - val_loss: 0.1250 - val_mae: 0.2454
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 98320.3906 - loss: 0.0743 - mae: 0.1764 - val_dmae: 136465.5625 - val_loss: 0.1239 - val_mae: 0.2448
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 98726.9844 - loss: 0.0741 - mae: 0.1771 - val_dmae: 136100.1406 - val_loss: 0.1236 - val_mae: 0.2441
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 99082.5000 - loss: 0.0744 - mae: 0.1777 - val_dmae: 135314.3125 - val_loss: 0.1215 - val_mae: 0.2427
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 97292.7500 - loss: 0.0729 - mae: 0.1745 - val_dmae: 134445.0312 - val_loss: 0.1210 - val_mae: 0.2412
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98383.5078 - loss: 0.0732 - mae: 0.1765 - val_dmae: 134669.0156 - val_loss: 0.1206 - val_mae: 0.2416
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98346.7500 - loss: 0.0735 - mae: 0.1764 - val_dmae: 134640.4844 - val_loss: 0.1212 - val_mae: 0.2415
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 97434.1094 - loss: 0.0716 - mae: 0.1748 - val_dmae: 133183.7188 - val_loss: 0.1185 - val_mae: 0.2389
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 97221.9531 - loss: 0.0717 - mae: 0.1744 - val_dmae: 132525.7969 - val_loss: 0.1178 - val_mae: 0.2377
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 96469.1797 - loss: 0.0707 - mae: 0.1731 - val_dmae: 133650.7656 - val_loss: 0.1195 - val_mae: 0.2398
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 96620.0547 - loss: 0.0699 - mae: 0.1733 - val_dmae: 132609.7500 - val_loss: 0.1169 - val_mae: 0.2379
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 95594.6016 - loss: 0.0690 - mae: 0.1715 - val_dmae: 132601.8750 - val_loss: 0.1174 - val_mae: 0.2379
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94895.7969 - loss: 0.0688 - mae: 0.1702 - val_dmae: 133034.8750 - val_loss: 0.1183 - val_mae: 0.2387
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 96733.3594 - loss: 0.0701 - mae: 0.1735 - val_dmae: 132212.8438 - val_loss: 0.1160 - val_mae: 0.2372
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95972.3281 - loss: 0.0694 - mae: 0.1722 - val_dmae: 132595.0781 - val_loss: 0.1169 - val_mae: 0.2379
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95618.8672 - loss: 0.0684 - mae: 0.1715 - val_dmae: 132131.9844 - val_loss: 0.1165 - val_mae: 0.2370
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 95737.0234 - loss: 0.0693 - mae: 0.1717 - val_dmae: 130672.5078 - val_loss: 0.1140 - val_mae: 0.2344
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95741.0625 - loss: 0.0687 - mae: 0.1717 - val_dmae: 130790.5781 - val_loss: 0.1137 - val_mae: 0.2346
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93882.2031 - loss: 0.0675 - mae: 0.1684 - val_dmae: 131328.7031 - val_loss: 0.1147 - val_mae: 0.2356
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 94822.6719 - loss: 0.0680 - mae: 0.1701 - val_dmae: 130069.5312 - val_loss: 0.1126 - val_mae: 0.2333
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 95833.3828 - loss: 0.0693 - mae: 0.1719 - val_dmae: 129944.6328 - val_loss: 0.1127 - val_mae: 0.2331
Epoch 52/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95073.4375 - loss: 0.0680 - mae: 0.1706 - val_dmae: 132068.4688 - val_loss: 0.1155 - val_mae: 0.2369
Epoch 53/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93695.2109 - loss: 0.0658 - mae: 0.1681 - val_dmae: 131612.1250 - val_loss: 0.1149 - val_mae: 0.2361
Epoch 54/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94826.5469 - loss: 0.0663 - mae: 0.1701 - val_dmae: 129882.1875 - val_loss: 0.1129 - val_mae: 0.2330
Epoch 55/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93811.0078 - loss: 0.0664 - mae: 0.1683 - val_dmae: 132614.7344 - val_loss: 0.1168 - val_mae: 0.2379
Epoch 56/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94036.7812 - loss: 0.0655 - mae: 0.1687 - val_dmae: 131942.4375 - val_loss: 0.1156 - val_mae: 0.2367
Epoch 57/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94689.1797 - loss: 0.0667 - mae: 0.1699 - val_dmae: 130989.7344 - val_loss: 0.1141 - val_mae: 0.2350
Epoch 58/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94764.6641 - loss: 0.0668 - mae: 0.1700 - val_dmae: 130996.3203 - val_loss: 0.1136 - val_mae: 0.2350
Epoch 59/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 93984.5859 - loss: 0.0661 - mae: 0.1686 - val_dmae: 129803.1406 - val_loss: 0.1119 - val_mae: 0.2329
Epoch 60/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93185.9766 - loss: 0.0653 - mae: 0.1672 - val_dmae: 130277.0000 - val_loss: 0.1133 - val_mae: 0.2337
Epoch 61/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93320.7656 - loss: 0.0651 - mae: 0.1674 - val_dmae: 131751.5938 - val_loss: 0.1159 - val_mae: 0.2363
Epoch 62/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93021.9062 - loss: 0.0646 - mae: 0.1669 - val_dmae: 132269.9688 - val_loss: 0.1173 - val_mae: 0.2373
Epoch 63/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93270.5078 - loss: 0.0643 - mae: 0.1673 - val_dmae: 131392.5000 - val_loss: 0.1149 - val_mae: 0.2357
Epoch 64/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 94547.9375 - loss: 0.0653 - mae: 0.1696 - val_dmae: 130165.5469 - val_loss: 0.1124 - val_mae: 0.2335
Epoch 64: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 71990.7812 - loss: 0.0312 - mae: 0.1291
Out[15]:
[0.025744302198290825, 63450.6328125, 0.11382366716861725]
In [16]:
model_gru = WalmartModel(2, base_layer=GRU, hidden_size=50, num_encoder_layers=3)
model_gru.train(train, val, 'results/walmart3.weights.h5', Adam(1e-3), batch_size=BATCH_SIZE)
model_gru.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 5s 46ms/step - dmae: 522511.1250 - loss: 1.1891 - mae: 0.9373 - val_dmae: 213606.8594 - val_loss: 0.2843 - val_mae: 0.3832
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 170825.9375 - loss: 0.2004 - mae: 0.3064 - val_dmae: 194278.1719 - val_loss: 0.2813 - val_mae: 0.3485
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114710.7266 - loss: 0.1170 - mae: 0.2058 - val_dmae: 169632.0938 - val_loss: 0.2220 - val_mae: 0.3043
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109222.3828 - loss: 0.1094 - mae: 0.1959 - val_dmae: 166604.8750 - val_loss: 0.2085 - val_mae: 0.2989
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108901.7344 - loss: 0.1062 - mae: 0.1954 - val_dmae: 165171.6406 - val_loss: 0.1978 - val_mae: 0.2963
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110931.5859 - loss: 0.1053 - mae: 0.1990 - val_dmae: 164692.3438 - val_loss: 0.1900 - val_mae: 0.2954
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111810.3828 - loss: 0.1038 - mae: 0.2006 - val_dmae: 164342.4844 - val_loss: 0.1845 - val_mae: 0.2948
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113627.7656 - loss: 0.1041 - mae: 0.2038 - val_dmae: 163454.2656 - val_loss: 0.1807 - val_mae: 0.2932
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112772.2031 - loss: 0.1033 - mae: 0.2023 - val_dmae: 162300.5625 - val_loss: 0.1762 - val_mae: 0.2911
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113181.4062 - loss: 0.1021 - mae: 0.2030 - val_dmae: 161449.3594 - val_loss: 0.1739 - val_mae: 0.2896
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112946.3906 - loss: 0.1010 - mae: 0.2026 - val_dmae: 160354.6719 - val_loss: 0.1716 - val_mae: 0.2877
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113122.7031 - loss: 0.1000 - mae: 0.2029 - val_dmae: 159607.0625 - val_loss: 0.1693 - val_mae: 0.2863
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113635.9062 - loss: 0.0996 - mae: 0.2039 - val_dmae: 159083.0312 - val_loss: 0.1671 - val_mae: 0.2854
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112388.2734 - loss: 0.0973 - mae: 0.2016 - val_dmae: 158053.2188 - val_loss: 0.1649 - val_mae: 0.2835
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 112053.6719 - loss: 0.0963 - mae: 0.2010 - val_dmae: 155956.9062 - val_loss: 0.1616 - val_mae: 0.2798
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110730.1250 - loss: 0.0945 - mae: 0.1986 - val_dmae: 154433.4688 - val_loss: 0.1590 - val_mae: 0.2770
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 110723.3828 - loss: 0.0940 - mae: 0.1986 - val_dmae: 153283.8281 - val_loss: 0.1560 - val_mae: 0.2750
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109041.7969 - loss: 0.0915 - mae: 0.1956 - val_dmae: 152760.1250 - val_loss: 0.1547 - val_mae: 0.2740
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107156.1875 - loss: 0.0894 - mae: 0.1922 - val_dmae: 150689.2188 - val_loss: 0.1511 - val_mae: 0.2703
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107777.8594 - loss: 0.0897 - mae: 0.1933 - val_dmae: 150385.8906 - val_loss: 0.1489 - val_mae: 0.2698
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106270.5547 - loss: 0.0877 - mae: 0.1906 - val_dmae: 147982.7812 - val_loss: 0.1439 - val_mae: 0.2655
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103599.9219 - loss: 0.0836 - mae: 0.1858 - val_dmae: 146357.1719 - val_loss: 0.1408 - val_mae: 0.2625
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 104554.9219 - loss: 0.0833 - mae: 0.1876 - val_dmae: 147480.4688 - val_loss: 0.1409 - val_mae: 0.2646
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102446.8594 - loss: 0.0805 - mae: 0.1838 - val_dmae: 147318.6719 - val_loss: 0.1404 - val_mae: 0.2643
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104542.9609 - loss: 0.0818 - mae: 0.1875 - val_dmae: 146512.3594 - val_loss: 0.1387 - val_mae: 0.2628
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103903.5547 - loss: 0.0808 - mae: 0.1864 - val_dmae: 143962.0156 - val_loss: 0.1359 - val_mae: 0.2583
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103628.2578 - loss: 0.0802 - mae: 0.1859 - val_dmae: 143120.3906 - val_loss: 0.1347 - val_mae: 0.2567
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 102359.7031 - loss: 0.0779 - mae: 0.1836 - val_dmae: 142439.8906 - val_loss: 0.1333 - val_mae: 0.2555
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103870.6406 - loss: 0.0793 - mae: 0.1863 - val_dmae: 140905.8438 - val_loss: 0.1308 - val_mae: 0.2528
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103042.4531 - loss: 0.0791 - mae: 0.1848 - val_dmae: 139826.5625 - val_loss: 0.1287 - val_mae: 0.2508
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103503.3125 - loss: 0.0783 - mae: 0.1857 - val_dmae: 139231.3594 - val_loss: 0.1278 - val_mae: 0.2498
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 101113.6172 - loss: 0.0762 - mae: 0.1814 - val_dmae: 139339.4844 - val_loss: 0.1261 - val_mae: 0.2500
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 101743.8672 - loss: 0.0769 - mae: 0.1825 - val_dmae: 138447.9375 - val_loss: 0.1263 - val_mae: 0.2484
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100470.1250 - loss: 0.0757 - mae: 0.1802 - val_dmae: 138735.6250 - val_loss: 0.1257 - val_mae: 0.2489
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 101696.5312 - loss: 0.0756 - mae: 0.1824 - val_dmae: 137453.3125 - val_loss: 0.1245 - val_mae: 0.2466
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 101505.0391 - loss: 0.0754 - mae: 0.1821 - val_dmae: 136430.2188 - val_loss: 0.1237 - val_mae: 0.2447
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98909.2812 - loss: 0.0732 - mae: 0.1774 - val_dmae: 136635.3281 - val_loss: 0.1238 - val_mae: 0.2451
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100025.6016 - loss: 0.0745 - mae: 0.1794 - val_dmae: 135123.2031 - val_loss: 0.1207 - val_mae: 0.2424
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 99219.7734 - loss: 0.0720 - mae: 0.1780 - val_dmae: 133366.7969 - val_loss: 0.1192 - val_mae: 0.2392
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98208.9531 - loss: 0.0710 - mae: 0.1762 - val_dmae: 135640.5781 - val_loss: 0.1225 - val_mae: 0.2433
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 97579.7188 - loss: 0.0701 - mae: 0.1750 - val_dmae: 137141.5625 - val_loss: 0.1231 - val_mae: 0.2460
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 98239.4766 - loss: 0.0705 - mae: 0.1762 - val_dmae: 134786.2812 - val_loss: 0.1206 - val_mae: 0.2418
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 98500.7188 - loss: 0.0709 - mae: 0.1767 - val_dmae: 133553.9062 - val_loss: 0.1195 - val_mae: 0.2396
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 97952.2188 - loss: 0.0709 - mae: 0.1757 - val_dmae: 133928.4375 - val_loss: 0.1188 - val_mae: 0.2403
Epoch 44: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - dmae: 76077.7891 - loss: 0.0361 - mae: 0.1365
Out[16]:
[0.028940564021468163, 66576.0078125, 0.11943028122186661]

We see that the results of the LSTM (63k) and the GRU (66k) are nearly the same. In the original LSTM paper (Hochreiter and Schmidhuber, 1997), the authors describe how the LSTM cell is able to contextualize longer sequences thanks to the three gates that control which information is maintained and which is forgotten. In fields where sequences are longer (e.g. NLP, where sentences typically span 10-20 words), LSTMs tend to outperform GRUs. In this dataset, since the sequence length is fixed to $S=2$, we see no significant difference between the two cells: the LSTM retrieves slightly better results here, while the GRU remains competitive with fewer parameters and thus less tendency to overfit.
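The parameter gap between the two cells is easy to verify. The following sketch (our own illustration, not part of the submitted scripts) counts the weights of a single recurrent layer using the classic formulations: the LSTM holds four weight blocks (input, forget and output gates plus the cell candidate) against the GRU's three (update and reset gates plus the candidate state). Note that Keras' default GRU (`reset_after=True`) adds a second bias vector per block, so the exact framework count differs slightly.

```python
def lstm_params(d_x: int, d_h: int) -> int:
    # 4 blocks (input, forget, output gates + cell candidate), each with
    # an input matrix (d_h x d_x), a recurrent matrix (d_h x d_h) and a bias.
    return 4 * (d_h * d_x + d_h * d_h + d_h)

def gru_params(d_x: int, d_h: int) -> int:
    # 3 blocks (update gate, reset gate, candidate state).
    return 3 * (d_h * d_x + d_h * d_h + d_h)

# With the hidden size used above (d_h=50) and, say, d_x=7 input features:
print(lstm_params(7, 50))  # 11600
print(gru_params(7, 50))   # 8700
```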

As a final improvement to our architecture, we can enable bidirectional processing in the recurrent layers. Bidirectional processing learns the recurrent information of an input sequence both left-to-right and right-to-left, and concatenates the two hidden contextualizations to return a new sequence representation that carries both past and future information. Bidirectionality has brought considerable improvements to recurrent layers in many tasks, since it allows the network to contextualize the current observation with future ones.
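The concatenation step can be illustrated with a minimal sketch, using a naive tanh RNN cell as a stand-in for the GRU (the names and shapes below are ours, not taken from model.py; in a real bidirectional layer each direction also has its own weights, which we share here for brevity):

```python
import numpy as np

def rnn_pass(X, W_x, W_h):
    # Simple tanh RNN over a sequence X of shape (S, d_x); returns (S, d_h).
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in X:
        h = np.tanh(W_x @ x_t + W_h @ h)
        states.append(h)
    return np.stack(states)

def bidirectional(X, W_x, W_h):
    # Forward pass left-to-right, backward pass over the reversed sequence
    # (re-reversed to align timesteps), then concatenate per-timestep
    # states: the output has shape (S, 2 * d_h).
    fwd = rnn_pass(X, W_x, W_h)
    bwd = rnn_pass(X[::-1], W_x, W_h)[::-1]
    return np.concatenate([fwd, bwd], axis=-1)

rng = np.random.default_rng(0)
S, d_x, d_h = 2, 7, 50
X = rng.normal(size=(S, d_x))
H = bidirectional(X, rng.normal(size=(d_h, d_x)), rng.normal(size=(d_h, d_h)))
print(H.shape)  # (2, 100)
```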

The next cell trains a bidirectional GRU-based encoder with two stacked feed-forward layers in the decoder.

In [26]:
model_bigru = WalmartModel(2, base_layer=GRU, hidden_size=50, num_encoder_layers=3, dropout=0.1, bidirectional=True)
model_bigru.train(train, val, 'results/walmart3.weights.h5', Adam(5e-4), batch_size=BATCH_SIZE)
model_bigru.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 11s 61ms/step - dmae: 511694.0312 - loss: 1.1551 - mae: 0.9179 - val_dmae: 189024.5156 - val_loss: 0.2542 - val_mae: 0.3391
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 142497.4375 - loss: 0.1530 - mae: 0.2556 - val_dmae: 186431.4219 - val_loss: 0.2618 - val_mae: 0.3344
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 108881.1016 - loss: 0.1089 - mae: 0.1953 - val_dmae: 174624.7188 - val_loss: 0.2308 - val_mae: 0.3133
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107733.7578 - loss: 0.1084 - mae: 0.1933 - val_dmae: 172770.8125 - val_loss: 0.2240 - val_mae: 0.3099
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107267.8594 - loss: 0.1057 - mae: 0.1924 - val_dmae: 171325.6094 - val_loss: 0.2185 - val_mae: 0.3073
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108637.7891 - loss: 0.1061 - mae: 0.1949 - val_dmae: 169672.3750 - val_loss: 0.2139 - val_mae: 0.3044
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108563.2109 - loss: 0.1049 - mae: 0.1948 - val_dmae: 168942.3594 - val_loss: 0.2103 - val_mae: 0.3031
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 109264.6172 - loss: 0.1058 - mae: 0.1960 - val_dmae: 167967.6250 - val_loss: 0.2068 - val_mae: 0.3013
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109888.1797 - loss: 0.1050 - mae: 0.1971 - val_dmae: 167338.2500 - val_loss: 0.2035 - val_mae: 0.3002
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 110322.7031 - loss: 0.1040 - mae: 0.1979 - val_dmae: 166839.9688 - val_loss: 0.2007 - val_mae: 0.2993
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 111002.6016 - loss: 0.1040 - mae: 0.1991 - val_dmae: 166145.4375 - val_loss: 0.1969 - val_mae: 0.2980
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111035.5859 - loss: 0.1032 - mae: 0.1992 - val_dmae: 165776.1094 - val_loss: 0.1946 - val_mae: 0.2974
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111394.9844 - loss: 0.1029 - mae: 0.1998 - val_dmae: 165218.1719 - val_loss: 0.1916 - val_mae: 0.2964
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111347.3828 - loss: 0.1022 - mae: 0.1997 - val_dmae: 164132.3594 - val_loss: 0.1886 - val_mae: 0.2944
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110852.0625 - loss: 0.1001 - mae: 0.1989 - val_dmae: 164826.7031 - val_loss: 0.1868 - val_mae: 0.2957
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 111642.8672 - loss: 0.1013 - mae: 0.2003 - val_dmae: 163517.5781 - val_loss: 0.1834 - val_mae: 0.2933
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110185.2578 - loss: 0.0988 - mae: 0.1977 - val_dmae: 162725.9375 - val_loss: 0.1807 - val_mae: 0.2919
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110214.9531 - loss: 0.0973 - mae: 0.1977 - val_dmae: 162752.4219 - val_loss: 0.1786 - val_mae: 0.2920
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 110505.2734 - loss: 0.0977 - mae: 0.1982 - val_dmae: 161590.3594 - val_loss: 0.1752 - val_mae: 0.2899
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 110159.6094 - loss: 0.0953 - mae: 0.1976 - val_dmae: 160823.5000 - val_loss: 0.1723 - val_mae: 0.2885
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109186.1094 - loss: 0.0936 - mae: 0.1959 - val_dmae: 159566.3281 - val_loss: 0.1683 - val_mae: 0.2862
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 108622.1641 - loss: 0.0918 - mae: 0.1949 - val_dmae: 158404.7500 - val_loss: 0.1638 - val_mae: 0.2842
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108196.7969 - loss: 0.0905 - mae: 0.1941 - val_dmae: 158937.7656 - val_loss: 0.1609 - val_mae: 0.2851
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 107088.9844 - loss: 0.0880 - mae: 0.1921 - val_dmae: 158154.9062 - val_loss: 0.1579 - val_mae: 0.2837
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 105292.7422 - loss: 0.0851 - mae: 0.1889 - val_dmae: 157132.9062 - val_loss: 0.1550 - val_mae: 0.2819
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 105309.1875 - loss: 0.0841 - mae: 0.1889 - val_dmae: 156571.4375 - val_loss: 0.1535 - val_mae: 0.2809
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 103167.1172 - loss: 0.0806 - mae: 0.1851 - val_dmae: 151948.5938 - val_loss: 0.1486 - val_mae: 0.2726
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 100862.3594 - loss: 0.0779 - mae: 0.1809 - val_dmae: 152059.4062 - val_loss: 0.1477 - val_mae: 0.2728
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 100803.9766 - loss: 0.0778 - mae: 0.1808 - val_dmae: 150152.9688 - val_loss: 0.1442 - val_mae: 0.2694
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 98984.7422 - loss: 0.0757 - mae: 0.1776 - val_dmae: 148794.4375 - val_loss: 0.1423 - val_mae: 0.2669
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 97321.9062 - loss: 0.0732 - mae: 0.1746 - val_dmae: 148345.7031 - val_loss: 0.1397 - val_mae: 0.2661
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 97453.7578 - loss: 0.0725 - mae: 0.1748 - val_dmae: 146179.5938 - val_loss: 0.1370 - val_mae: 0.2622
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 96209.7656 - loss: 0.0715 - mae: 0.1726 - val_dmae: 145746.8906 - val_loss: 0.1345 - val_mae: 0.2615
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 96604.7344 - loss: 0.0714 - mae: 0.1733 - val_dmae: 143477.1875 - val_loss: 0.1311 - val_mae: 0.2574
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 96501.2734 - loss: 0.0699 - mae: 0.1731 - val_dmae: 143601.0781 - val_loss: 0.1310 - val_mae: 0.2576
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - dmae: 95101.6719 - loss: 0.0683 - mae: 0.1706 - val_dmae: 142305.2656 - val_loss: 0.1277 - val_mae: 0.2553
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 95438.9453 - loss: 0.0676 - mae: 0.1712 - val_dmae: 142679.8125 - val_loss: 0.1265 - val_mae: 0.2560
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 94747.6562 - loss: 0.0663 - mae: 0.1700 - val_dmae: 137811.0156 - val_loss: 0.1201 - val_mae: 0.2472
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93081.6562 - loss: 0.0651 - mae: 0.1670 - val_dmae: 139094.8281 - val_loss: 0.1209 - val_mae: 0.2495
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - dmae: 94390.5391 - loss: 0.0659 - mae: 0.1693 - val_dmae: 135983.8594 - val_loss: 0.1169 - val_mae: 0.2439
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92927.3203 - loss: 0.0644 - mae: 0.1667 - val_dmae: 140755.0312 - val_loss: 0.1227 - val_mae: 0.2525
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 94259.3516 - loss: 0.0654 - mae: 0.1691 - val_dmae: 133848.7500 - val_loss: 0.1140 - val_mae: 0.2401
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 93657.9922 - loss: 0.0650 - mae: 0.1680 - val_dmae: 137869.3594 - val_loss: 0.1192 - val_mae: 0.2473
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 92605.9922 - loss: 0.0634 - mae: 0.1661 - val_dmae: 134989.3438 - val_loss: 0.1148 - val_mae: 0.2422
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92207.6094 - loss: 0.0634 - mae: 0.1654 - val_dmae: 135320.2812 - val_loss: 0.1154 - val_mae: 0.2428
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 93152.2578 - loss: 0.0638 - mae: 0.1671 - val_dmae: 130911.6562 - val_loss: 0.1084 - val_mae: 0.2348
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 92816.6250 - loss: 0.0637 - mae: 0.1665 - val_dmae: 131079.3750 - val_loss: 0.1083 - val_mae: 0.2351
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 91563.8828 - loss: 0.0619 - mae: 0.1643 - val_dmae: 135479.6406 - val_loss: 0.1136 - val_mae: 0.2430
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 91637.3594 - loss: 0.0623 - mae: 0.1644 - val_dmae: 135391.9844 - val_loss: 0.1133 - val_mae: 0.2429
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 91618.4531 - loss: 0.0623 - mae: 0.1644 - val_dmae: 139915.0625 - val_loss: 0.1206 - val_mae: 0.2510
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 92722.4375 - loss: 0.0629 - mae: 0.1663 - val_dmae: 134980.9062 - val_loss: 0.1145 - val_mae: 0.2421
Epoch 51: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - dmae: 69979.5078 - loss: 0.0293 - mae: 0.1255
Out[26]:
[0.02632186934351921, 64908.85546875, 0.1164395734667778]

The bidirectional GRU achieves a better MAE than the unidirectional GRU but does not outperform the LSTM-based model. The main reason why bidirectionality does not clearly improve the baseline results is, again, (1) the sequence length of the data stream and (2) the information collected in our dataset. Bidirectionality usually pays off with longer sequences, where future information is important to give meaning to the whole sequence (e.g. in natural language). In this case, the sequences are short ($S=2$) and left-to-right processing is the more natural choice, since in the real environment the data stream is also generated from left to right and the target is always a future outcome of the previous input observations.

As a final experiment, keeping the unidirectional GRU cell as the base recurrent layer (its results are on par with the LSTM's at a lower parameter count), we increase the sequence length to $S=3$ to see whether the architecture can be further improved:
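Increasing $S$ only changes how the input windows are sliced from the series. The real logic lives in data.py's WalmartDataset; the sketch below is a hedged illustration of the windowing, where the horizon of 2 follows the $t+2$ target described at the top of the notebook and all names are our own:

```python
import numpy as np

def make_windows(series, S=3, horizon=2):
    """Slice a (T, d_x) series into inputs (N, S, d_x) and targets (N,).

    Each window covers timesteps [t, t + S) and its target is the sales
    value (assumed to be column 0) at timestep t + S - 1 + horizon.
    """
    T = len(series)
    X, y = [], []
    for t in range(T - S - horizon + 1):
        X.append(series[t:t + S])
        y.append(series[t + S - 1 + horizon, 0])
    return np.stack(X), np.array(y)

series = np.arange(20, dtype=float).reshape(10, 2)  # toy (T=10, d_x=2) series
X, y = make_windows(series, S=3, horizon=2)
print(X.shape, y.shape)  # (6, 3, 2) (6,)
```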

In [28]:
model = WalmartModel(3, base_layer=GRU, hidden_size=50, num_encoder_layers=3, activation='relu')
model.train(train, val, 'results/walmart3.weights.h5', Adam(1e-4), batch_size=BATCH_SIZE)
model.evaluate(test)
Epoch 1/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 6s 48ms/step - dmae: 540893.8750 - loss: 1.2754 - mae: 0.9703 - val_dmae: 449361.0312 - val_loss: 0.9042 - val_mae: 0.8061
Epoch 2/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 538479.1875 - loss: 1.2653 - mae: 0.9660 - val_dmae: 447091.1875 - val_loss: 0.8951 - val_mae: 0.8020
Epoch 3/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 535209.2500 - loss: 1.2513 - mae: 0.9601 - val_dmae: 443291.3125 - val_loss: 0.8800 - val_mae: 0.7952
Epoch 4/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 529381.6875 - loss: 1.2268 - mae: 0.9497 - val_dmae: 436300.8125 - val_loss: 0.8528 - val_mae: 0.7827
Epoch 5/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 518687.8125 - loss: 1.1826 - mae: 0.9305 - val_dmae: 423372.2500 - val_loss: 0.8035 - val_mae: 0.7595
Epoch 6/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 499458.9688 - loss: 1.1044 - mae: 0.8960 - val_dmae: 401133.8438 - val_loss: 0.7227 - val_mae: 0.7196
Epoch 7/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 466482.4062 - loss: 0.9773 - mae: 0.8368 - val_dmae: 364942.8438 - val_loss: 0.6002 - val_mae: 0.6547
Epoch 8/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 411903.0312 - loss: 0.7847 - mae: 0.7389 - val_dmae: 308532.4062 - val_loss: 0.4345 - val_mae: 0.5535
Epoch 9/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 328275.8438 - loss: 0.5317 - mae: 0.5889 - val_dmae: 236963.0156 - val_loss: 0.2656 - val_mae: 0.4251
Epoch 10/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 224794.2812 - loss: 0.2873 - mae: 0.4033 - val_dmae: 190205.9688 - val_loss: 0.1850 - val_mae: 0.3412
Epoch 11/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 149743.4062 - loss: 0.1608 - mae: 0.2686 - val_dmae: 178606.9844 - val_loss: 0.1855 - val_mae: 0.3204
Epoch 12/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 132380.6406 - loss: 0.1441 - mae: 0.2375 - val_dmae: 174568.6406 - val_loss: 0.1837 - val_mae: 0.3132
Epoch 13/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 128074.0703 - loss: 0.1378 - mae: 0.2298 - val_dmae: 171316.6250 - val_loss: 0.1807 - val_mae: 0.3073
Epoch 14/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 123321.8047 - loss: 0.1309 - mae: 0.2212 - val_dmae: 168996.5781 - val_loss: 0.1795 - val_mae: 0.3032
Epoch 15/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 121276.4062 - loss: 0.1307 - mae: 0.2176 - val_dmae: 167217.9844 - val_loss: 0.1792 - val_mae: 0.3000
Epoch 16/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 119898.6719 - loss: 0.1266 - mae: 0.2151 - val_dmae: 165710.2500 - val_loss: 0.1779 - val_mae: 0.2973
Epoch 17/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 120748.2500 - loss: 0.1311 - mae: 0.2166 - val_dmae: 164155.7812 - val_loss: 0.1751 - val_mae: 0.2945
Epoch 18/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118620.0938 - loss: 0.1283 - mae: 0.2128 - val_dmae: 163283.8750 - val_loss: 0.1744 - val_mae: 0.2929
Epoch 19/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 119612.2266 - loss: 0.1292 - mae: 0.2146 - val_dmae: 162294.1719 - val_loss: 0.1721 - val_mae: 0.2911
Epoch 20/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 117537.4141 - loss: 0.1259 - mae: 0.2108 - val_dmae: 161565.1562 - val_loss: 0.1715 - val_mae: 0.2898
Epoch 21/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 118515.7188 - loss: 0.1237 - mae: 0.2126 - val_dmae: 161125.4375 - val_loss: 0.1707 - val_mae: 0.2890
Epoch 22/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116537.5859 - loss: 0.1236 - mae: 0.2091 - val_dmae: 160738.9531 - val_loss: 0.1703 - val_mae: 0.2883
Epoch 23/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 115967.6094 - loss: 0.1221 - mae: 0.2080 - val_dmae: 159929.4531 - val_loss: 0.1678 - val_mae: 0.2869
Epoch 24/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 116750.2656 - loss: 0.1233 - mae: 0.2094 - val_dmae: 159184.4531 - val_loss: 0.1662 - val_mae: 0.2856
Epoch 25/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114966.4922 - loss: 0.1223 - mae: 0.2062 - val_dmae: 158413.7344 - val_loss: 0.1646 - val_mae: 0.2842
Epoch 26/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114064.5000 - loss: 0.1196 - mae: 0.2046 - val_dmae: 157375.3125 - val_loss: 0.1621 - val_mae: 0.2823
Epoch 27/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114311.6562 - loss: 0.1190 - mae: 0.2051 - val_dmae: 156778.7031 - val_loss: 0.1611 - val_mae: 0.2812
Epoch 28/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 114478.1328 - loss: 0.1203 - mae: 0.2054 - val_dmae: 155944.7656 - val_loss: 0.1593 - val_mae: 0.2797
Epoch 29/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 113959.6016 - loss: 0.1200 - mae: 0.2044 - val_dmae: 154698.1406 - val_loss: 0.1567 - val_mae: 0.2775
Epoch 30/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112983.5547 - loss: 0.1161 - mae: 0.2027 - val_dmae: 154906.0625 - val_loss: 0.1580 - val_mae: 0.2779
Epoch 31/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 113660.1641 - loss: 0.1153 - mae: 0.2039 - val_dmae: 153222.2500 - val_loss: 0.1539 - val_mae: 0.2749
Epoch 32/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 112292.4141 - loss: 0.1149 - mae: 0.2014 - val_dmae: 153329.9375 - val_loss: 0.1551 - val_mae: 0.2751
Epoch 33/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 111415.6875 - loss: 0.1130 - mae: 0.1999 - val_dmae: 151707.0312 - val_loss: 0.1515 - val_mae: 0.2721
Epoch 34/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 111281.8516 - loss: 0.1122 - mae: 0.1996 - val_dmae: 151017.2188 - val_loss: 0.1500 - val_mae: 0.2709
Epoch 35/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 112849.5625 - loss: 0.1135 - mae: 0.2024 - val_dmae: 149998.5000 - val_loss: 0.1486 - val_mae: 0.2691
Epoch 36/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110330.5859 - loss: 0.1082 - mae: 0.1979 - val_dmae: 149588.4531 - val_loss: 0.1481 - val_mae: 0.2683
Epoch 37/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 110554.2500 - loss: 0.1129 - mae: 0.1983 - val_dmae: 149249.8750 - val_loss: 0.1482 - val_mae: 0.2677
Epoch 38/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110144.3750 - loss: 0.1089 - mae: 0.1976 - val_dmae: 147686.2344 - val_loss: 0.1454 - val_mae: 0.2649
Epoch 39/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110666.6172 - loss: 0.1090 - mae: 0.1985 - val_dmae: 146978.9062 - val_loss: 0.1445 - val_mae: 0.2637
Epoch 40/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109784.4375 - loss: 0.1084 - mae: 0.1969 - val_dmae: 146587.7031 - val_loss: 0.1440 - val_mae: 0.2630
Epoch 41/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 109904.5234 - loss: 0.1099 - mae: 0.1972 - val_dmae: 145439.9844 - val_loss: 0.1421 - val_mae: 0.2609
Epoch 42/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 110131.3125 - loss: 0.1081 - mae: 0.1976 - val_dmae: 145447.8438 - val_loss: 0.1427 - val_mae: 0.2609
Epoch 43/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 109797.9609 - loss: 0.1075 - mae: 0.1970 - val_dmae: 144671.6094 - val_loss: 0.1416 - val_mae: 0.2595
Epoch 44/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108804.9453 - loss: 0.1065 - mae: 0.1952 - val_dmae: 144289.3281 - val_loss: 0.1413 - val_mae: 0.2588
Epoch 45/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108217.3594 - loss: 0.1064 - mae: 0.1941 - val_dmae: 143205.9688 - val_loss: 0.1390 - val_mae: 0.2569
Epoch 46/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108837.2109 - loss: 0.1069 - mae: 0.1952 - val_dmae: 144665.3281 - val_loss: 0.1422 - val_mae: 0.2595
Epoch 47/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107000.0391 - loss: 0.1057 - mae: 0.1919 - val_dmae: 142787.6250 - val_loss: 0.1388 - val_mae: 0.2561
Epoch 48/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108004.4297 - loss: 0.1049 - mae: 0.1937 - val_dmae: 142663.8125 - val_loss: 0.1387 - val_mae: 0.2559
Epoch 49/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - dmae: 108662.1172 - loss: 0.1064 - mae: 0.1949 - val_dmae: 142986.2656 - val_loss: 0.1394 - val_mae: 0.2565
Epoch 50/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106500.9219 - loss: 0.1055 - mae: 0.1911 - val_dmae: 142966.2031 - val_loss: 0.1394 - val_mae: 0.2565
Epoch 51/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108066.3672 - loss: 0.1040 - mae: 0.1939 - val_dmae: 141710.1406 - val_loss: 0.1377 - val_mae: 0.2542
Epoch 52/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108070.0625 - loss: 0.1044 - mae: 0.1939 - val_dmae: 142136.1719 - val_loss: 0.1380 - val_mae: 0.2550
Epoch 53/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108798.0391 - loss: 0.1051 - mae: 0.1952 - val_dmae: 141703.1719 - val_loss: 0.1376 - val_mae: 0.2542
Epoch 54/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 108090.4453 - loss: 0.1038 - mae: 0.1939 - val_dmae: 142201.6094 - val_loss: 0.1383 - val_mae: 0.2551
Epoch 55/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107303.2500 - loss: 0.1035 - mae: 0.1925 - val_dmae: 141007.3281 - val_loss: 0.1370 - val_mae: 0.2530
Epoch 56/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106790.6719 - loss: 0.1030 - mae: 0.1916 - val_dmae: 140742.4375 - val_loss: 0.1356 - val_mae: 0.2525
Epoch 57/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 108648.7656 - loss: 0.1031 - mae: 0.1949 - val_dmae: 140455.1562 - val_loss: 0.1362 - val_mae: 0.2520
Epoch 58/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105881.9219 - loss: 0.1018 - mae: 0.1899 - val_dmae: 139839.9688 - val_loss: 0.1347 - val_mae: 0.2509
Epoch 59/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107379.3359 - loss: 0.1027 - mae: 0.1926 - val_dmae: 140462.4531 - val_loss: 0.1358 - val_mae: 0.2520
Epoch 60/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106281.9766 - loss: 0.1009 - mae: 0.1907 - val_dmae: 140288.6875 - val_loss: 0.1358 - val_mae: 0.2517
Epoch 61/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107474.5000 - loss: 0.1025 - mae: 0.1928 - val_dmae: 139368.0156 - val_loss: 0.1339 - val_mae: 0.2500
Epoch 62/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105230.5703 - loss: 0.1008 - mae: 0.1888 - val_dmae: 140052.0312 - val_loss: 0.1355 - val_mae: 0.2512
Epoch 63/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107418.7188 - loss: 0.1024 - mae: 0.1927 - val_dmae: 139620.7500 - val_loss: 0.1344 - val_mae: 0.2505
Epoch 64/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106750.2500 - loss: 0.1021 - mae: 0.1915 - val_dmae: 139406.6094 - val_loss: 0.1340 - val_mae: 0.2501
Epoch 65/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105582.7500 - loss: 0.1015 - mae: 0.1894 - val_dmae: 139395.6094 - val_loss: 0.1343 - val_mae: 0.2501
Epoch 66/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105239.0000 - loss: 0.0998 - mae: 0.1888 - val_dmae: 138382.9062 - val_loss: 0.1332 - val_mae: 0.2482
Epoch 67/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106342.9844 - loss: 0.1037 - mae: 0.1908 - val_dmae: 138564.4375 - val_loss: 0.1330 - val_mae: 0.2486
Epoch 68/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105170.2578 - loss: 0.1006 - mae: 0.1887 - val_dmae: 137998.2656 - val_loss: 0.1322 - val_mae: 0.2476
Epoch 69/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 104818.1641 - loss: 0.0981 - mae: 0.1880 - val_dmae: 139308.3438 - val_loss: 0.1347 - val_mae: 0.2499
Epoch 70/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106699.0078 - loss: 0.1004 - mae: 0.1914 - val_dmae: 139299.6562 - val_loss: 0.1341 - val_mae: 0.2499
Epoch 71/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105436.6016 - loss: 0.0982 - mae: 0.1891 - val_dmae: 138671.7500 - val_loss: 0.1336 - val_mae: 0.2488
Epoch 72/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 106009.1016 - loss: 0.0989 - mae: 0.1902 - val_dmae: 139040.2031 - val_loss: 0.1338 - val_mae: 0.2494
Epoch 73/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - dmae: 107466.4688 - loss: 0.1009 - mae: 0.1928 - val_dmae: 137619.5469 - val_loss: 0.1307 - val_mae: 0.2469
Epoch 74/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 107295.4922 - loss: 0.1017 - mae: 0.1925 - val_dmae: 138153.8125 - val_loss: 0.1321 - val_mae: 0.2478
Epoch 75/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 103870.9141 - loss: 0.0978 - mae: 0.1863 - val_dmae: 139248.2969 - val_loss: 0.1338 - val_mae: 0.2498
Epoch 76/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105554.5938 - loss: 0.1001 - mae: 0.1894 - val_dmae: 138325.0938 - val_loss: 0.1323 - val_mae: 0.2481
Epoch 77/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105591.8828 - loss: 0.0984 - mae: 0.1894 - val_dmae: 139266.5000 - val_loss: 0.1342 - val_mae: 0.2498
Epoch 78/2000
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - dmae: 105442.5234 - loss: 0.0995 - mae: 0.1892 - val_dmae: 137623.3750 - val_loss: 0.1314 - val_mae: 0.2469
Epoch 78: early stopping
45/45 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - dmae: 63942.0352 - loss: 0.0242 - mae: 0.1147
Out[28]:
[0.02329966053366661, 61717.328125, 0.11071430891752243]
In [29]:
plot_series(model, [train, val, test], title='Predictions with S=3').show()

By increasing the sequence length $S$ we obtain our best result so far: a test MAE of roughly 61k sales.
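To make concrete what $S$ controls, the sketch below builds length-$S$ input windows that predict the target two timesteps ahead (matching the $t+2$ horizon described above). Note that `make_windows` is a toy helper of our own for illustration; the actual windowing is done by `WalmartDataset` in `data.py`.

```python
import numpy as np

def make_windows(series: np.ndarray, S: int, horizon: int = 2):
    # Each sample uses timesteps [t-S+1 .. t] to predict timestep t+horizon.
    X = np.stack([series[t - S + 1 : t + 1]
                  for t in range(S - 1, len(series) - horizon)])
    y = series[S - 1 + horizon:]
    return X, y

sales = np.arange(10, dtype=float)  # toy weekly series
X, y = make_windows(sales, S=3)
# X[0] = [0., 1., 2.] predicts y[0] = 4. (two steps ahead of timestep 2)
```

A larger $S$ gives the recurrent encoder more past context per sample, at the cost of fewer training windows from the same series.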

Regularization hyperparameters¶

There are other hyperparameters of the network that are less interpretable than the configurations explained above (bidirectionality, GRU vs. LSTM, dropout, model dimension, etc.): the kernel regularizer, the weight initializer and the activation function. For these hyperparameters we prepared a grid search to find the best configuration.

In [ ]:
grid = OrderedDict(
    regularizer = [L1(1e-3), L2(1e-3), L1L2(1e-3)],
    initializer=['random_normal', 'glorot_uniform'],
    activation=['tanh', 'relu']
)

def applydeep(lists, func):
    # Apply func element-wise to every inner iterable.
    return [list(map(func, item)) for item in lists]

df = pd.DataFrame(columns=['train', 'val', 'test'], index=pd.MultiIndex.from_product(applydeep(grid.values(), str)))
for i, params in enumerate(product(*grid.values())):
    params = dict(zip(grid.keys(), params))
    model = WalmartModel(seq_len=3, base_layer=GRU, num_encoder_layers=3, num_decoder_layers=2, bidirectional=False, dropout=0.1, **params)
    # validate (and early-stop) on the validation split, never on the test split
    model.train(train, val, f'results/walmart.weights.h5', Adam(1e-4), batch_size=BATCH_SIZE)
    (_, train_mae, _), (_, val_mae, _), (_, test_mae, _) = map(model.evaluate, (train, val, test))
    df.loc[tuple(map(str, params.values()))] = [train_mae, val_mae, test_mae]
    df.to_csv('grid.csv')
df.index.names = list(grid.keys())
df.to_csv('grid.csv')  # re-save with named index levels
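Once the grid finishes, the winning configuration can be read off as the row with the lowest validation MAE. A minimal sketch of that lookup (the index/column names mirror `df` above, but the numbers here are synthetic placeholders, not our actual results):

```python
import pandas as pd

# Synthetic stand-in for the grid-search results saved to grid.csv;
# the real values come from the training loop above.
index = pd.MultiIndex.from_product(
    [['L1', 'L2'], ['random_normal'], ['tanh', 'relu']],
    names=['regularizer', 'initializer', 'activation'])
df = pd.DataFrame({'train': [0.11, 0.12, 0.10, 0.13],
                   'val':   [0.14, 0.13, 0.15, 0.16],
                   'test':  [0.15, 0.14, 0.16, 0.17]}, index=index)

best = df['val'].idxmin()  # index tuple with the lowest validation MAE
print(best, df.loc[best, 'test'])
```

Selecting on the validation column (rather than the test column) keeps the test split untouched until the final report.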